diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..6413432f --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @fabianfett @gwynne diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 49d2cef1..1eaf9a87 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,14 +18,11 @@ jobs: fail-fast: false matrix: swift-image: - - swift:5.7-jammy - - swift:5.8-jammy - - swift:5.9-jammy - swift:5.10-jammy - - swiftlang/swift:nightly-main-jammy - include: - - swift-image: swift:5.10-jammy - code-coverage: true + - swift:6.0-noble + - swift:6.1-noble + - swiftlang/swift:nightly-6.2-noble + - swiftlang/swift:nightly-main-noble container: ${{ matrix.swift-image }} runs-on: ubuntu-latest steps: @@ -36,34 +33,38 @@ jobs: [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" swift --version + - name: Install curl for Codecov + run: apt-get update -y -q && apt-get install -y curl - name: Check out package - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Run unit tests with Thread Sanitizer - env: - CODE_COVERAGE: ${{ matrix.code-coverage && '--enable-code-coverage' || '' }} + shell: bash run: | - swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' --sanitize=thread ${CODE_COVERAGE} + # https://github.com/swiftlang/swift/issues/74042 was never fixed in 5.10 and swift-crypto hits it in 6.0 as well + SANITIZE="$([[ "${SWIFT_VERSION}" =~ ^swift-(5|6\.0) ]] || echo '--sanitize=thread')" + swift test --filter='^(PostgresNIOTests|ConnectionPoolModuleTests)' ${SANITIZE} --enable-code-coverage - name: Submit code coverage - if: ${{ matrix.code-coverage }} - uses: vapor/swift-codecov-action@v0.2 + uses: vapor/swift-codecov-action@v0.3 + with: + codecov_token: ${{ secrets.CODECOV_TOKEN }} linux-integration-and-dependencies: strategy: fail-fast: false matrix: postgres-image: - - postgres:16 - - postgres:14 - - postgres:12 + - postgres:17 + - postgres:15 + - postgres:13 include: - - postgres-image: postgres:16 + - postgres-image: postgres:17 postgres-auth: scram-sha-256 - - postgres-image: postgres:14 + - postgres-image: postgres:15 postgres-auth: md5 - - postgres-image: postgres:12 + - postgres-image: postgres:13 postgres-auth: trust container: - image: swift:5.10-jammy + image: swift:6.1-noble volumes: [ 'pgrunshare:/var/run/postgresql' ] runs-on: ubuntu-latest env: @@ -108,15 +109,15 @@ jobs: [[ -z "${SWIFT_VERSION}" ]] && SWIFT_VERSION="$(cat /.swift_tag 2>/dev/null || true)" printf 'OS: %s\nTag: %s\nVersion:\n' "${SWIFT_PLATFORM}-${RUNNER_ARCH}" "${SWIFT_VERSION}" && swift --version - name: Check out package - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { path: 'postgres-nio' } - name: Run integration tests run: swift test --package-path postgres-nio --filter=^IntegrationTests - name: Check out postgres-kit dependent - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { repository: 'vapor/postgres-kit', path: 'postgres-kit' } - name: Check out fluent-postgres-driver dependent - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: { repository: 'vapor/fluent-postgres-driver', path: 'fluent-postgres-driver' } - name: Use local package in dependents run: | @@ -137,14 +138,14 @@ jobs: postgres-auth: # Only test one auth method on macOS, Linux tests will cover the others - scram-sha-256 - xcode-version: - - '~14.3' - - '~15.0' + macos-version: + - 'macos-14' + - 'macos-15' include: - - xcode-version: '~14.3' - macos-version: 'macos-13' - - xcode-version: '~15.0' - macos-version: 'macos-14' + - macos-version: 'macos-14' + xcode-version: 'latest-stable' + - macos-version: 'macos-15' + xcode-version: 'latest-stable' runs-on: ${{ matrix.macos-version }} env: POSTGRES_HOSTNAME: 127.0.0.1 @@ -162,28 +163,23 @@ jobs: - name: Install Postgres, setup DB and auth, and wait for server start run: | export PATH="$(brew --prefix)/opt/${POSTGRES_FORMULA}/bin:$PATH" PGDATA=/tmp/vapor-postgres-test - # ** BEGIN ** Work around bug in both Homebrew and GHA - (brew upgrade python@3.11 || true) && (brew link --force --overwrite python@3.11 || true) - (brew upgrade python@3.12 || true) && (brew link --force --overwrite python@3.12 || true) - brew upgrade - # ** END ** Work around bug in both Homebrew and GHA brew install --overwrite "${POSTGRES_FORMULA}" brew link --overwrite --force "${POSTGRES_FORMULA}" initdb --locale=C --auth-host "${POSTGRES_AUTH_METHOD}" -U "${POSTGRES_USER}" --pwfile=<(echo "${POSTGRES_PASSWORD}") pg_ctl start --wait timeout-minutes: 15 - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v5 - name: Run all tests run: swift test - + api-breakage: if: github.event_name == 'pull_request' runs-on: ubuntu-latest - container: swift:jammy + container: swift:noble steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v5 with: fetch-depth: 0 # https://github.com/actions/checkout/issues/766 @@ -191,22 +187,3 @@ jobs: run: | git config --global --add safe.directory "${GITHUB_WORKSPACE}" swift package diagnose-api-breaking-changes origin/main - - gh-codeql: - if: ${{ false }} - runs-on: ubuntu-latest - container: swift:jammy - permissions: { actions: write, contents: read, security-events: write } - steps: - - name: Check out code - uses: actions/checkout@v4 - - name: Mark repo safe in non-fake global config - run: git config --global --add safe.directory "${GITHUB_WORKSPACE}" - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: swift - - name: Perform build - run: swift build - - name: Run CodeQL analyze - uses: github/codeql-action/analyze@v3 diff --git a/Benchmarks/.gitignore b/Benchmarks/.gitignore new file mode 100644 index 00000000..24e5b0a1 --- /dev/null +++ b/Benchmarks/.gitignore @@ -0,0 +1 @@ +.build diff --git a/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift new file mode 100644 index 00000000..9cc535d4 --- /dev/null +++ b/Benchmarks/Benchmarks/ConnectionPoolBenchmarks/ConnectionPoolBenchmarks.swift @@ -0,0 +1,99 @@ +import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import Benchmark + +let benchmarks: @Sendable () -> Void = { + Benchmark("Lease/Release 1k requests: 50 parallel", configuration: .init(scalingFactor: .kilo)) { benchmark in + let clock = MockClock() + let factory = MockConnectionFactory(autoMaxStreams: 1) + var configuration = ConnectionPoolConfiguration() + configuration.maximumConnectionSoftLimit = 50 + configuration.maximumConnectionHardLimit = 50 + + let pool = ConnectionPool( + configuration: configuration, + idGenerator: ConnectionIDGenerator(), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup { taskGroup in + + taskGroup.addTask { + await pool.run() + } + + let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + + benchmark.startMeasurement() + + for parallel in 0..(autoMaxStreams: 1) + var configuration = ConnectionPoolConfiguration() + configuration.maximumConnectionSoftLimit = 50 + configuration.maximumConnectionHardLimit = 50 + + let pool = ConnectionPool( + configuration: configuration, + idGenerator: ConnectionIDGenerator(), + keepAliveBehavior: MockPingPongBehavior(keepAliveFrequency: nil, connectionType: MockConnection.self), + observabilityDelegate: NoOpConnectionPoolMetrics(connectionIDType: MockConnection.ID.self), + clock: clock + ) { + try await factory.makeConnection(id: $0, for: $1) + } + + await withTaskGroup { taskGroup in + + taskGroup.addTask { + await pool.run() + } + + let sequential = benchmark.scaledIterations.upperBound / configuration.maximumConnectionSoftLimit + + benchmark.startMeasurement() + + for _ in benchmark.scaledIterations { + do { + try await pool.withConnection { connection in + blackHole(connection) + } + } catch { + fatalError() + } + } + + benchmark.stopMeasurement() + + taskGroup.cancelAll() + } + } +} diff --git a/Benchmarks/Package.swift b/Benchmarks/Package.swift new file mode 100644 index 00000000..11407176 --- /dev/null +++ b/Benchmarks/Package.swift @@ -0,0 +1,28 @@ +// swift-tools-version: 6.0 + +import PackageDescription + +let package = Package( + name: "benchmarks", + platforms: [ + .macOS("14") + ], + dependencies: [ + .package(path: "../"), + .package(url: "https://github.com/ordo-one/package-benchmark.git", from: "1.29.0"), + ], + targets: [ + .executableTarget( + name: "ConnectionPoolBenchmarks", + dependencies: [ + .product(name: "_ConnectionPoolModule", package: "postgres-nio"), + .product(name: "_ConnectionPoolTestUtils", package: "postgres-nio"), + .product(name: "Benchmark", package: "package-benchmark"), + ], + path: "Benchmarks/ConnectionPoolBenchmarks", + plugins: [ + .plugin(name: "BenchmarkPlugin", package: "package-benchmark") + ] + ), + ] +) diff --git a/Package.swift b/Package.swift index 4d008371..af2d07ae 100644 --- a/Package.swift +++ b/Package.swift @@ -1,6 +1,10 @@ -// swift-tools-version:5.7 +// swift-tools-version:5.10 import PackageDescription +let swiftSettings: [SwiftSetting] = [ + .enableUpcomingFeature("StrictConcurrency"), +] + let package = Package( name: "postgres-nio", platforms: [ @@ -12,17 +16,18 @@ let package = Package( products: [ .library(name: "PostgresNIO", targets: ["PostgresNIO"]), .library(name: "_ConnectionPoolModule", targets: ["_ConnectionPoolModule"]), + .library(name: "_ConnectionPoolTestUtils", targets: ["_ConnectionPoolTestUtils"]), ], dependencies: [ .package(url: "https://github.com/apple/swift-atomics.git", from: "1.2.0"), .package(url: "https://github.com/apple/swift-collections.git", from: "1.0.4"), - .package(url: "https://github.com/apple/swift-nio.git", from: "2.59.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.25.0"), - .package(url: "https://github.com/apple/swift-crypto.git", "2.0.0" ..< "4.0.0"), + .package(url: "https://github.com/apple/swift-crypto.git", "3.9.0" ..< "4.0.0"), .package(url: "https://github.com/apple/swift-metrics.git", from: "2.4.1"), .package(url: "https://github.com/apple/swift-log.git", from: "1.5.3"), - .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.4.1"), + .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.5.0"), ], targets: [ .target( @@ -31,6 +36,7 @@ let package = Package( .target(name: "_ConnectionPoolModule"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Crypto", package: "swift-crypto"), + .product(name: "_CryptoExtras", package: "swift-crypto"), .product(name: "Logging", package: "swift-log"), .product(name: "Metrics", package: "swift-metrics"), .product(name: "NIO", package: "swift-nio"), @@ -41,7 +47,8 @@ let package = Package( .product(name: "NIOSSL", package: "swift-nio-ssl"), .product(name: "NIOFoundationCompat", package: "swift-nio"), .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), - ] + ], + swiftSettings: swiftSettings ), .target( name: "_ConnectionPoolModule", @@ -49,7 +56,17 @@ let package = Package( .product(name: "Atomics", package: "swift-atomics"), .product(name: "DequeModule", package: "swift-collections"), ], - path: "Sources/ConnectionPoolModule" + path: "Sources/ConnectionPoolModule", + swiftSettings: swiftSettings + ), + .target( + name: "_ConnectionPoolTestUtils", + dependencies: [ + "_ConnectionPoolModule", + .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), + ], + path: "Sources/ConnectionPoolTestUtils", + swiftSettings: swiftSettings ), .testTarget( name: "PostgresNIOTests", @@ -57,24 +74,28 @@ let package = Package( .target(name: "PostgresNIO"), .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "NIOTestUtils", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), .testTarget( name: "ConnectionPoolModuleTests", dependencies: [ .target(name: "_ConnectionPoolModule"), + .target(name: "_ConnectionPoolTestUtils"), .product(name: "DequeModule", package: "swift-collections"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOConcurrencyHelpers", package: "swift-nio"), .product(name: "NIOEmbedded", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), .testTarget( name: "IntegrationTests", dependencies: [ .target(name: "PostgresNIO"), .product(name: "NIOTestUtils", package: "swift-nio"), - ] + ], + swiftSettings: swiftSettings ), ] ) diff --git a/README.md b/README.md index 9e7d4e3b..6d03b8da 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,5 @@

- - - - PostgresNIO - +PostgresNIO

@@ -16,7 +12,7 @@ Continuous Integration - Swift 5.7+ + Swift 5.10+ SSWG Incubation Level: Graduated @@ -43,7 +39,7 @@ detailed look at all of the classes, structs, protocols, and more. ## Getting started -Interested in an example? We prepared a simple [Birthday example](/vapor/postgres-nio/tree/main/Snippets/Birthdays.swift) +Interested in an example? We prepared a simple [Birthday example](https://github.com/vapor/postgres-nio/blob/main/Snippets/Birthdays.swift) in the Snippets folder. #### Adding the dependency @@ -167,7 +163,7 @@ Please see [SECURITY.md] for details on the security process. [Team Chat]: https://discord.gg/vapor [MIT License]: LICENSE [Continuous Integration]: https://github.com/vapor/postgres-nio/actions -[Swift 5.7]: https://swift.org +[Swift 5.10]: https://swift.org [Security.md]: https://github.com/vapor/.github/blob/main/SECURITY.md [`PostgresConnection`]: https://api.vapor.codes/postgresnio/documentation/postgresnio/postgresconnection diff --git a/Sources/ConnectionPoolModule/ConnectionLease.swift b/Sources/ConnectionPoolModule/ConnectionLease.swift new file mode 100644 index 00000000..77591a58 --- /dev/null +++ b/Sources/ConnectionPoolModule/ConnectionLease.swift @@ -0,0 +1,17 @@ +public struct ConnectionLease: Sendable { + public var connection: Connection + + @usableFromInline + let _release: @Sendable (Connection) -> () + + @inlinable + public init(connection: Connection, release: @escaping @Sendable (Connection) -> Void) { + self.connection = connection + self._release = release + } + + @inlinable + public func release() { + self._release(self.connection) + } +} diff --git a/Sources/ConnectionPoolModule/ConnectionPool.swift b/Sources/ConnectionPoolModule/ConnectionPool.swift index 9f25e82c..ee72337d 100644 --- a/Sources/ConnectionPoolModule/ConnectionPool.swift +++ b/Sources/ConnectionPoolModule/ConnectionPool.swift @@ -1,6 +1,6 @@ @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -public struct ConnectionAndMetadata { +public struct ConnectionAndMetadata: Sendable { public var connection: Connection @@ -88,7 +88,7 @@ public protocol ConnectionRequestProtocol: Sendable { /// A function that is called with a connection or a /// `PoolError`. - func complete(with: Result) + func complete(with: Result, ConnectionPoolError>) } @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @@ -271,15 +271,14 @@ public final class ConnectionPool< } } + @inlinable public func run() async { await withTaskCancellationHandler { - #if swift(>=5.8) && os(Linux) || swift(>=5.9) if #available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) { return await withDiscardingTaskGroup() { taskGroup in await self.run(in: &taskGroup) } } - #endif return await withTaskGroup(of: Void.self) { taskGroup in await self.run(in: &taskGroup) } @@ -313,16 +312,16 @@ public final class ConnectionPool< case scheduleTimer(StateMachine.Timer) } - #if swift(>=5.8) && os(Linux) || swift(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) - private func run(in taskGroup: inout DiscardingTaskGroup) async { + @inlinable + /* private */ func run(in taskGroup: inout DiscardingTaskGroup) async { for await event in self.eventStream { self.runEvent(event, in: &taskGroup) } } - #endif - private func run(in taskGroup: inout TaskGroup) async { + @inlinable + /* private */ func run(in taskGroup: inout TaskGroup) async { var running = 0 for await event in self.eventStream { running += 1 @@ -335,7 +334,8 @@ public final class ConnectionPool< } } - private func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { + @inlinable + /* private */ func runEvent(_ event: NewPoolActions, in taskGroup: inout some TaskGroupProtocol) { switch event { case .makeConnection(let request): self.makeConnection(for: request, in: &taskGroup) @@ -402,8 +402,11 @@ public final class ConnectionPool< /*private*/ func runRequestAction(_ action: StateMachine.RequestAction) { switch action { case .leaseConnection(let requests, let connection): + let lease = ConnectionLease(connection: connection) { connection in + self.releaseConnection(connection) + } for request in requests { - request.complete(with: .success(connection)) + request.complete(with: .success(lease)) } case .failRequest(let request, let error): @@ -419,7 +422,7 @@ public final class ConnectionPool< @inlinable /*private*/ func makeConnection(for request: StateMachine.ConnectionRequest, in taskGroup: inout some TaskGroupProtocol) { - taskGroup.addTask { + taskGroup.addTask_ { self.observabilityDelegate.startedConnecting(id: request.connectionID) do { @@ -468,7 +471,7 @@ public final class ConnectionPool< /*private*/ func runKeepAlive(_ connection: Connection, in taskGroup: inout some TaskGroupProtocol) { self.observabilityDelegate.keepAliveTriggered(id: connection.id) - taskGroup.addTask { + taskGroup.addTask_ { do { try await self.keepAliveBehavior.runKeepAlive(for: connection) @@ -495,7 +498,7 @@ public final class ConnectionPool< } @usableFromInline - enum TimerRunResult { + enum TimerRunResult: Sendable { case timerTriggered case timerCancelled case cancellationContinuationFinished @@ -503,15 +506,11 @@ public final class ConnectionPool< @inlinable /*private*/ func runTimer(_ timer: StateMachine.Timer, in poolGroup: inout some TaskGroupProtocol) { - poolGroup.addTask { () async -> () in + poolGroup.addTask_ { () async -> () in await withTaskGroup(of: TimerRunResult.self, returning: Void.self) { taskGroup in taskGroup.addTask { do { - #if swift(>=5.8) && os(Linux) || swift(>=5.9) try await self.clock.sleep(for: timer.duration) - #else - try await self.clock.sleep(until: self.clock.now.advanced(by: timer.duration), tolerance: nil) - #endif return .timerTriggered } catch { return .timerCancelled @@ -571,33 +570,25 @@ extension PoolConfiguration { } } -#if swift(<5.9) -// This should be removed once we support Swift 5.9+ only -extension AsyncStream { - static func makeStream( - of elementType: Element.Type = Element.self, - bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded - ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { - var continuation: AsyncStream.Continuation! - let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } -} -#endif - @usableFromInline protocol TaskGroupProtocol { - mutating func addTask(operation: @escaping @Sendable () async -> Void) + // We need to call this `addTask_` because some Swift versions define this + // under exactly this name and others have different attributes. So let's pick + // a name that doesn't clash anywhere and implement it using the standard `addTask`. + mutating func addTask_(operation: @escaping @Sendable () async -> Void) } -#if swift(>=5.8) && os(Linux) || swift(>=5.9) @available(macOS 14.0, iOS 17.0, tvOS 17.0, watchOS 10.0, *) -extension DiscardingTaskGroup: TaskGroupProtocol {} -#endif +extension DiscardingTaskGroup: TaskGroupProtocol { + @inlinable + mutating func addTask_(operation: @escaping @Sendable () async -> Void) { + self.addTask(priority: nil, operation: operation) + } +} extension TaskGroup: TaskGroupProtocol { @inlinable - mutating func addTask(operation: @escaping @Sendable () async -> Void) { + mutating func addTask_(operation: @escaping @Sendable () async -> Void) { self.addTask(priority: nil, operation: operation) } } diff --git a/Sources/ConnectionPoolModule/ConnectionPoolError.swift b/Sources/ConnectionPoolModule/ConnectionPoolError.swift index 1f1e1d2c..3abfe778 100644 --- a/Sources/ConnectionPoolModule/ConnectionPoolError.swift +++ b/Sources/ConnectionPoolModule/ConnectionPoolError.swift @@ -1,16 +1,25 @@ public struct ConnectionPoolError: Error, Hashable { - enum Base: Error, Hashable { + @usableFromInline + enum Base: Error, Hashable, Sendable { case requestCancelled case poolShutdown } - private let base: Base + @usableFromInline + let base: Base + @inlinable init(_ base: Base) { self.base = base } /// The connection requests got cancelled - public static let requestCancelled = ConnectionPoolError(.requestCancelled) + @inlinable + public static var requestCancelled: Self { + ConnectionPoolError(.requestCancelled) + } /// The connection requests can't be fulfilled as the pool has already been shutdown - public static let poolShutdown = ConnectionPoolError(.poolShutdown) + @inlinable + public static var poolShutdown: Self { + ConnectionPoolError(.poolShutdown) + } } diff --git a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift index 35f30dcb..fc1e300c 100644 --- a/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift +++ b/Sources/ConnectionPoolModule/ConnectionPoolObservabilityDelegate.swift @@ -37,7 +37,7 @@ public protocol ConnectionPoolObservabilityDelegate: Sendable { func requestQueueDepthChanged(_ newDepth: Int) } -public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { +public struct NoOpConnectionPoolMetrics: ConnectionPoolObservabilityDelegate { public init(connectionIDType: ConnectionID.Type) {} public func startedConnecting(id: ConnectionID) {} diff --git a/Sources/ConnectionPoolModule/ConnectionRequest.swift b/Sources/ConnectionPoolModule/ConnectionRequest.swift index 19ed9bd2..d6654a27 100644 --- a/Sources/ConnectionPoolModule/ConnectionRequest.swift +++ b/Sources/ConnectionPoolModule/ConnectionRequest.swift @@ -5,23 +5,24 @@ public struct ConnectionRequest: ConnectionRequest public var id: ID @usableFromInline - private(set) var continuation: CheckedContinuation + private(set) var continuation: CheckedContinuation, any Error> @inlinable init( id: Int, - continuation: CheckedContinuation + continuation: CheckedContinuation, any Error> ) { self.id = id self.continuation = continuation } - public func complete(with result: Result) { + public func complete(with result: Result, ConnectionPoolError>) { self.continuation.resume(with: result) } } -fileprivate let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() +@usableFromInline +let requestIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator() @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) extension ConnectionPool where Request == ConnectionRequest { @@ -44,7 +45,8 @@ extension ConnectionPool where Request == ConnectionRequest { ) } - public func leaseConnection() async throws -> Connection { + @inlinable + public func leaseConnection() async throws -> ConnectionLease { let requestID = requestIDGenerator.next() let connection = try await withTaskCancellationHandler { @@ -52,7 +54,7 @@ extension ConnectionPool where Request == ConnectionRequest { throw CancellationError() } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, Error>) in let request = Request( id: requestID, continuation: continuation @@ -67,9 +69,10 @@ extension ConnectionPool where Request == ConnectionRequest { return connection } + @inlinable public func withConnection(_ closure: (Connection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() - defer { self.releaseConnection(connection) } - return try await closure(connection) + let lease = try await self.leaseConnection() + defer { lease.release() } + return try await closure(lease.connection) } } diff --git a/Sources/ConnectionPoolModule/NIOLock.swift b/Sources/ConnectionPoolModule/NIOLock.swift index dbc7dbe9..b6cd7164 100644 --- a/Sources/ConnectionPoolModule/NIOLock.swift +++ b/Sources/ConnectionPoolModule/NIOLock.swift @@ -24,6 +24,13 @@ import WinSDK import Glibc #elseif canImport(Musl) import Musl +#elseif canImport(Bionic) +import Bionic +#elseif canImport(WASILibc) +import WASILibc +#if canImport(wasi_pthread) +import wasi_pthread +#endif #else #error("The concurrency NIOLock module was unable to identify your C library.") #endif @@ -37,61 +44,61 @@ typealias LockPrimitive = pthread_mutex_t #endif @usableFromInline -enum LockOperations { } +enum LockOperations {} extension LockOperations { @inlinable static func create(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) InitializeSRWLock(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) var attr = pthread_mutexattr_t() pthread_mutexattr_init(&attr) debugOnly { pthread_mutexattr_settype(&attr, .init(PTHREAD_MUTEX_ERRORCHECK)) } - + let err = pthread_mutex_init(mutex, &attr) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func destroy(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) // SRWLOCK does not need to be free'd -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_destroy(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func lock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) AcquireSRWLockExclusive(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_lock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } - + @inlinable static func unlock(_ mutex: UnsafeMutablePointer) { mutex.assertValidAlignment() -#if os(Windows) + #if os(Windows) ReleaseSRWLockExclusive(mutex) -#else + #elseif (compiler(<6.1) && !os(WASI)) || (compiler(>=6.1) && _runtime(_multithreaded)) let err = pthread_mutex_unlock(mutex) precondition(err == 0, "\(#function) failed in pthread_mutex with error \(err)") -#endif + #endif } } @@ -125,49 +132,52 @@ extension LockOperations { // See also: https://github.com/apple/swift/pull/40000 @usableFromInline final class LockStorage: ManagedBuffer { - + @inlinable static func create(value: Value) -> Self { let buffer = Self.create(minimumCapacity: 1) { _ in - return value + value } - let storage = unsafeDowncast(buffer, to: Self.self) - + // Intentionally using a force cast here to avoid a miss compiliation in 5.10. + // This is as fast as an unsafeDownCast since ManagedBuffer is inlined and the optimizer + // can eliminate the upcast/downcast pair + let storage = buffer as! Self + storage.withUnsafeMutablePointers { _, lockPtr in LockOperations.create(lockPtr) } - + return storage } - + @inlinable func lock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.lock(lockPtr) } } - + @inlinable func unlock() { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.unlock(lockPtr) } } - + @inlinable deinit { self.withUnsafeMutablePointerToElements { lockPtr in LockOperations.destroy(lockPtr) } } - + @inlinable func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { try self.withUnsafeMutablePointerToElements { lockPtr in - return try body(lockPtr) + try body(lockPtr) } } - + @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { try self.withUnsafeMutablePointers { valuePtr, lockPtr in @@ -178,21 +188,18 @@ final class LockStorage: ManagedBuffer { } } -extension LockStorage: @unchecked Sendable { } - /// A threading lock based on `libpthread` instead of `libdispatch`. /// -/// - note: ``NIOLock`` has reference semantics. +/// - Note: ``NIOLock`` has reference semantics. /// /// This object provides a lock on top of a single `pthread_mutex_t`. This kind /// of lock is safe to use with `libpthread`-based threading models, such as the /// one used by NIO. On Windows, the lock is based on the substantially similar /// `SRWLOCK` type. -@usableFromInline struct NIOLock { @usableFromInline internal let _storage: LockStorage - + /// Create a new lock. @inlinable init() { @@ -219,7 +226,7 @@ struct NIOLock { @inlinable internal func withLockPrimitive(_ body: (UnsafeMutablePointer) throws -> T) rethrows -> T { - return try self._storage.withLockPrimitive(body) + try self._storage.withLockPrimitive(body) } } @@ -242,12 +249,12 @@ extension NIOLock { } @inlinable - func withLockVoid(_ body: () throws -> Void) rethrows -> Void { + func withLockVoid(_ body: () throws -> Void) rethrows { try self.withLock(body) } } -extension NIOLock: Sendable {} +extension NIOLock: @unchecked Sendable {} extension UnsafeMutablePointer { @inlinable @@ -263,6 +270,10 @@ extension UnsafeMutablePointer { /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - // FIXME: duplicated with NIO. - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } diff --git a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift index e5a3e6a2..c9cd89e0 100644 --- a/Sources/ConnectionPoolModule/NIOLockedValueBox.swift +++ b/Sources/ConnectionPoolModule/NIOLockedValueBox.swift @@ -17,7 +17,7 @@ /// Provides locked access to `Value`. /// -/// - note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` +/// - Note: ``NIOLockedValueBox`` has reference semantics and holds the `Value` /// alongside a lock behind a reference. /// /// This is no different than creating a ``Lock`` and protecting all @@ -39,8 +39,48 @@ struct NIOLockedValueBox { /// Access the `Value`, allowing mutation of it. @inlinable func withLockedValue(_ mutate: (inout Value) throws -> T) rethrows -> T { - return try self._storage.withLockedValue(mutate) + try self._storage.withLockedValue(mutate) + } + + /// Provides an unsafe view over the lock and its value. + /// + /// This can be beneficial when you require fine grained control over the lock in some + /// situations but don't want lose the benefits of ``withLockedValue(_:)`` in others by + /// switching to ``NIOLock``. + var unsafe: Unsafe { + Unsafe(_storage: self._storage) + } + + /// Provides an unsafe view over the lock and its value. + struct Unsafe { + @usableFromInline + let _storage: LockStorage + + /// Manually acquire the lock. + @inlinable + func lock() { + self._storage.lock() + } + + /// Manually release the lock. + @inlinable + func unlock() { + self._storage.unlock() + } + + /// Mutate the value, assuming the lock has been acquired manually. + /// + /// - Parameter mutate: A closure with scoped access to the value. + /// - Returns: The result of the `mutate` closure. + @inlinable + func withValueAssumingLockIsAcquired( + _ mutate: (_ value: inout Value) throws -> Result + ) rethrows -> Result { + try self._storage.withUnsafeMutablePointerToHeader { value in + try mutate(&value.pointee) + } + } } } -extension NIOLockedValueBox: Sendable where Value: Sendable {} +extension NIOLockedValueBox: @unchecked Sendable where Value: Sendable {} diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift index 833365fa..a8e97ffd 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionGroup.swift @@ -132,6 +132,12 @@ extension PoolStateMachine { @usableFromInline var info: ConnectionAvailableInfo + + @inlinable + init(use: ConnectionUse, info: ConnectionAvailableInfo) { + self.use = use + self.info = info + } } mutating func refillConnections() -> [ConnectionRequest] { @@ -592,7 +598,7 @@ extension PoolStateMachine { let newConnectionRequest: ConnectionRequest? if self.connections.count < self.minimumConcurrentConnections { - newConnectionRequest = .init(connectionID: self.generator.next()) + newConnectionRequest = self.createNewConnection() } else { newConnectionRequest = .none } @@ -623,7 +629,7 @@ extension PoolStateMachine { // MARK: - Private functions - - @usableFromInline + @inlinable /*private*/ func getConnectionUse(index: Int) -> ConnectionUse { switch index { case 0.. AvailableConnectionContext { precondition(self.connections[index].isAvailable) let use = self.getConnectionUse(index: index) diff --git a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift index 2fb68a2d..9912f13a 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine+ConnectionState.swift @@ -164,7 +164,7 @@ extension PoolStateMachine { } } - @usableFromInline + @inlinable var isLeased: Bool { switch self.state { case .leased: @@ -174,7 +174,7 @@ extension PoolStateMachine { } } - @usableFromInline + @inlinable var isConnected: Bool { switch self.state { case .idle, .leased: diff --git a/Sources/ConnectionPoolModule/PoolStateMachine.swift b/Sources/ConnectionPoolModule/PoolStateMachine.swift index 3b996033..8d995fa2 100644 --- a/Sources/ConnectionPoolModule/PoolStateMachine.swift +++ b/Sources/ConnectionPoolModule/PoolStateMachine.swift @@ -1,7 +1,9 @@ #if canImport(Darwin) import Darwin -#else +#elseif canImport(Glibc) import Glibc +#elseif canImport(Musl) +import Musl #endif @usableFromInline @@ -432,6 +434,7 @@ struct PoolStateMachine< fatalError("Unimplemented") } + @usableFromInline mutating func triggerForceShutdown() -> Action { switch self.poolState { case .running: diff --git a/Sources/ConnectionPoolModule/TinyFastSequence.swift b/Sources/ConnectionPoolModule/TinyFastSequence.swift index dff8a30b..df140c98 100644 --- a/Sources/ConnectionPoolModule/TinyFastSequence.swift +++ b/Sources/ConnectionPoolModule/TinyFastSequence.swift @@ -29,6 +29,12 @@ struct TinyFastSequence: Sequence { self.base = .none(reserveCapacity: 0) case 1: self.base = .one(collection.first!, reserveCapacity: 0) + case 2: + self.base = .two( + collection.first!, + collection[collection.index(after: collection.startIndex)], + reserveCapacity: 0 + ) default: if let collection = collection as? Array { self.base = .n(collection) @@ -46,7 +52,7 @@ struct TinyFastSequence: Sequence { case 1: self.base = .one(max2Sequence.first!, reserveCapacity: 0) case 2: - self.base = .n(Array(max2Sequence)) + self.base = .two(max2Sequence.first!, max2Sequence.second!, reserveCapacity: 0) default: fatalError() } @@ -169,7 +175,7 @@ struct TinyFastSequence: Sequence { case .n(let array): if self.index < array.endIndex { - defer { self.index += 1} + defer { self.index += 1 } return array[self.index] } return nil diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift b/Sources/ConnectionPoolTestUtils/MockClock.swift similarity index 84% rename from Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift rename to Sources/ConnectionPoolTestUtils/MockClock.swift index cd08d54e..34bf17e3 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockClock.swift +++ b/Sources/ConnectionPoolTestUtils/MockClock.swift @@ -1,31 +1,32 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import Atomics import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockClock: Clock { - struct Instant: InstantProtocol, Comparable { - typealias Duration = Swift.Duration +public final class MockClock: Clock { + public struct Instant: InstantProtocol, Comparable { + public typealias Duration = Swift.Duration - func advanced(by duration: Self.Duration) -> Self { + public func advanced(by duration: Self.Duration) -> Self { .init(self.base + duration) } - func duration(to other: Self) -> Self.Duration { + public func duration(to other: Self) -> Self.Duration { self.base - other.base } private var base: Swift.Duration - init(_ base: Duration) { + public init(_ base: Duration) { self.base = base } - static func < (lhs: Self, rhs: Self) -> Bool { + public static func < (lhs: Self, rhs: Self) -> Bool { lhs.base < rhs.base } - static func == (lhs: Self, rhs: Self) -> Bool { + public static func == (lhs: Self, rhs: Self) -> Bool { lhs.base == rhs.base } } @@ -58,16 +59,18 @@ final class MockClock: Clock { var continuation: CheckedContinuation } - typealias Duration = Swift.Duration + public typealias Duration = Swift.Duration - var minimumResolution: Duration { .nanoseconds(1) } + public var minimumResolution: Duration { .nanoseconds(1) } - var now: Instant { self.stateBox.withLockedValue { $0.now } } + public var now: Instant { self.stateBox.withLockedValue { $0.now } } private let stateBox = NIOLockedValueBox(State()) private let waiterIDGenerator = ManagedAtomic(0) - func sleep(until deadline: Instant, tolerance: Duration?) async throws { + public init() {} + + public func sleep(until deadline: Instant, tolerance: Duration?) async throws { let waiterID = self.waiterIDGenerator.loadThenWrappingIncrement(ordering: .relaxed) return try await withTaskCancellationHandler { @@ -131,7 +134,7 @@ final class MockClock: Clock { } @discardableResult - func nextTimerScheduled() async -> Instant { + public func nextTimerScheduled() async -> Instant { await withCheckedContinuation { (continuation: CheckedContinuation) in let instant = self.stateBox.withLockedValue { state -> Instant? in if let scheduled = state.nextDeadlines.popFirst() { @@ -149,7 +152,7 @@ final class MockClock: Clock { } } - func advance(to deadline: Instant) { + public func advance(to deadline: Instant) { let waiters = self.stateBox.withLockedValue { state -> ArraySlice in precondition(deadline > state.now, "Time can only move forward") state.now = deadline diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift b/Sources/ConnectionPoolTestUtils/MockConnection.swift similarity index 86% rename from Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift rename to Sources/ConnectionPoolTestUtils/MockConnection.swift index f826ea04..db5c3ef7 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnection.swift +++ b/Sources/ConnectionPoolTestUtils/MockConnection.swift @@ -1,11 +1,11 @@ +import _ConnectionPoolModule import DequeModule -@testable import _ConnectionPoolModule +import NIOConcurrencyHelpers -// Sendability enforced through the lock -final class MockConnection: PooledConnection, Sendable { - typealias ID = Int +public final class MockConnection: PooledConnection, Sendable { + public typealias ID = Int - let id: ID + public let id: ID private enum State { case running([CheckedContinuation], [@Sendable ((any Error)?) -> ()]) @@ -15,11 +15,11 @@ final class MockConnection: PooledConnection, Sendable { private let lock: NIOLockedValueBox = NIOLockedValueBox(.running([], [])) - init(id: Int) { + public init(id: Int) { self.id = id } - var signalToClose: Void { + public var signalToClose: Void { get async throws { try await withCheckedThrowingContinuation { continuation in let runRightAway = self.lock.withLockedValue { state -> Bool in @@ -41,7 +41,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { let enqueued = self.lock.withLockedValue { state -> Bool in switch state { case .closed: @@ -64,7 +64,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func close() { + public func close() { let continuations = self.lock.withLockedValue { state -> [CheckedContinuation] in switch state { case .running(let continuations, let callbacks): @@ -81,7 +81,7 @@ final class MockConnection: PooledConnection, Sendable { } } - func closeIfClosing() { + public func closeIfClosing() { let callbacks = self.lock.withLockedValue { state -> [@Sendable ((any Error)?) -> ()] in switch state { case .running, .closed: @@ -100,7 +100,7 @@ final class MockConnection: PooledConnection, Sendable { } extension MockConnection: CustomStringConvertible { - var description: String { + public var description: String { let state = self.lock.withLockedValue { $0 } return "MockConnection(id: \(self.id), state: \(state))" } diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift similarity index 71% rename from Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift rename to Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift index eec2e7c3..936b47cc 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockConnectionFactory.swift +++ b/Sources/ConnectionPoolTestUtils/MockConnectionFactory.swift @@ -1,14 +1,15 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockConnectionFactory where Clock.Duration == Duration { - typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator - typealias Request = ConnectionRequest - typealias KeepAliveBehavior = MockPingPongBehavior - typealias MetricsDelegate = NoOpConnectionPoolMetrics - typealias ConnectionID = Int - typealias Connection = MockConnection +public final class MockConnectionFactory: Sendable where Clock.Duration == Duration { + public typealias ConnectionIDGenerator = _ConnectionPoolModule.ConnectionIDGenerator + public typealias Request = ConnectionRequest + public typealias KeepAliveBehavior = MockPingPongBehavior + public typealias MetricsDelegate = NoOpConnectionPoolMetrics + public typealias ConnectionID = Int + public typealias Connection = MockConnection let stateBox = NIOLockedValueBox(State()) @@ -20,18 +21,33 @@ final class MockConnectionFactory where Clock.Duratio var runningConnections = [ConnectionID: Connection]() } - var pendingConnectionAttemptsCount: Int { + let autoMaxStreams: UInt16? + + public init(autoMaxStreams: UInt16? = nil) { + self.autoMaxStreams = autoMaxStreams + } + + public var pendingConnectionAttemptsCount: Int { self.stateBox.withLockedValue { $0.attempts.count } } - var runningConnections: [Connection] { + public var runningConnections: [Connection] { self.stateBox.withLockedValue { Array($0.runningConnections.values) } } - func makeConnection( + public func makeConnection( id: Int, for pool: ConnectionPool, NoOpConnectionPoolMetrics, Clock> ) async throws -> ConnectionAndMetadata { + if let autoMaxStreams = self.autoMaxStreams { + let connection = MockConnection(id: id) + Task { + try? await connection.signalToClose + connection.closeIfClosing() + } + return .init(connection: connection, maximalStreamsOnConnection: autoMaxStreams) + } + // we currently don't support cancellation when creating a connection let result = try await withCheckedThrowingContinuation { (checkedContinuation: CheckedContinuation<(MockConnection, UInt16), any Error>) in let waiter = self.stateBox.withLockedValue { state -> (CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>)? in @@ -52,7 +68,7 @@ final class MockConnectionFactory where Clock.Duratio } @discardableResult - func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { + public func nextConnectAttempt(_ closure: (ConnectionID) async throws -> UInt16) async rethrows -> Connection { let (connectionID, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>), Never>) in let attempt = self.stateBox.withLockedValue { state -> (ConnectionID, CheckedContinuation<(MockConnection, UInt16), any Error>)? in if let attempt = state.attempts.popFirst() { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift similarity index 82% rename from Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift rename to Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift index 637f096c..de1a7275 100644 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockPingPongBehaviour.swift +++ b/Sources/ConnectionPoolTestUtils/MockPingPongBehaviour.swift @@ -1,9 +1,10 @@ -@testable import _ConnectionPoolModule +import _ConnectionPoolModule import DequeModule +import NIOConcurrencyHelpers @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) -final class MockPingPongBehavior: ConnectionKeepAliveBehavior { - let keepAliveFrequency: Duration? +public final class MockPingPongBehavior: ConnectionKeepAliveBehavior { + public let keepAliveFrequency: Duration? let stateBox = NIOLockedValueBox(State()) @@ -13,11 +14,11 @@ final class MockPingPongBehavior: ConnectionKeepAl var waiter = Deque), Never>>() } - init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { + public init(keepAliveFrequency: Duration?, connectionType: Connection.Type) { self.keepAliveFrequency = keepAliveFrequency } - func runKeepAlive(for connection: Connection) async throws { + public func runKeepAlive(for connection: Connection) async throws { precondition(self.keepAliveFrequency != nil) // we currently don't support cancellation when creating a connection @@ -40,7 +41,7 @@ final class MockPingPongBehavior: ConnectionKeepAl } @discardableResult - func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { + public func nextKeepAlive(_ closure: (Connection) async throws -> Bool) async rethrows -> Connection { let (connection, continuation) = await withCheckedContinuation { (continuation: CheckedContinuation<(Connection, CheckedContinuation), Never>) in let run = self.stateBox.withLockedValue { state -> (Connection, CheckedContinuation)? in if let run = state.runs.popFirst() { diff --git a/Sources/ConnectionPoolTestUtils/MockRequest.swift b/Sources/ConnectionPoolTestUtils/MockRequest.swift new file mode 100644 index 00000000..3dd8b0fb --- /dev/null +++ b/Sources/ConnectionPoolTestUtils/MockRequest.swift @@ -0,0 +1,27 @@ +import _ConnectionPoolModule + +public final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { + public struct ID: Hashable, Sendable { + var objectID: ObjectIdentifier + + init(_ request: MockRequest) { + self.objectID = ObjectIdentifier(request) + } + } + + public init(connectionType: Connection.Type = Connection.self) {} + + public var id: ID { ID(self) } + + public static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { + lhs.id == rhs.id + } + + public func hash(into hasher: inout Hasher) { + hasher.combine(self.id) + } + + public func complete(with: Result, ConnectionPoolError>) { + + } +} diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift index dd0f5404..b260723a 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift @@ -192,9 +192,22 @@ extension PostgresConnection { /// - Parameters: /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). - /// - tls: The TLS mode to use. Defaults to ``TLS-swift.struct/disable``. + /// - tls: The TLS mode to use. + public init(establishedChannel channel: Channel, tls: PostgresConnection.Configuration.TLS, username: String, password: String?, database: String?) { + self.init(endpointInfo: .configureChannel(channel), tls: tls, username: username, password: password, database: database) + } + + /// Create a configuration for establishing a connection to a Postgres server over a preestablished + /// `NIOCore/Channel`. + /// + /// This is provided for calling code which wants to manage the underlying connection transport on its + /// own, such as when tunneling a connection through SSH. + /// + /// - Parameters: + /// - channel: The `NIOCore/Channel` to use. The channel must already be active and connected to an + /// endpoint (i.e. `NIOCore/Channel/isActive` must be `true`). public init(establishedChannel channel: Channel, username: String, password: String?, database: String?) { - self.init(endpointInfo: .configureChannel(channel), tls: .disable, username: username, password: password, database: database) + self.init(establishedChannel: channel, tls: .disable, username: username, password: password, database: database) } // MARK: - Implementation details diff --git a/Sources/PostgresNIO/Connection/PostgresConnection.swift b/Sources/PostgresNIO/Connection/PostgresConnection.swift index eb9dc791..e267d8f9 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection.swift @@ -60,18 +60,18 @@ public final class PostgresConnection: @unchecked Sendable { func start(configuration: InternalConfiguration) -> EventLoopFuture { // 1. configure handlers - let configureSSLCallback: ((Channel) throws -> ())? + let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> ())? switch configuration.tls.base { case .prefer(let context), .require(let context): - configureSSLCallback = { channel in + configureSSLCallback = { channel, postgresChannelHandler in channel.eventLoop.assertInEventLoop() let sslHandler = try NIOSSLClientHandler( context: context, serverHostname: configuration.serverNameForTLS ) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first) + try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(postgresChannelHandler)) } case .disable: configureSSLCallback = nil @@ -530,6 +530,110 @@ extension PostgresConnection { throw error // rethrow with more metadata } } + + #if compiler(>=6.0) + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ process: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } + } + #else + /// Puts the connection into an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function starts a transaction by running a `BEGIN` query on the connection against the database. It then + /// lends the connection to the user provided closure. The user can then modify the database as they wish. If the user + /// provided closure returns successfully, the function will attempt to commit the changes by running a `COMMIT` + /// query against the database. If the user provided closure throws an error, the function will attempt to rollback the + /// changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + _ process: (PostgresConnection) async throws -> Result + ) async throws -> Result { + do { + try await self.query("BEGIN;", logger: logger) + } catch { + throw PostgresTransactionError(file: file, line: line, beginError: error) + } + + var closureHasFinished: Bool = false + do { + let value = try await process(self) + closureHasFinished = true + try await self.query("COMMIT;", logger: logger) + return value + } catch { + var transactionError = PostgresTransactionError(file: file, line: line) + if !closureHasFinished { + transactionError.closureError = error + do { + try await self.query("ROLLBACK;", logger: logger) + } catch { + transactionError.rollbackError = error + } + } else { + transactionError.commitError = error + } + + throw transactionError + } + } + #endif } // MARK: EventLoopFuture interface diff --git a/Sources/PostgresNIO/Docs.docc/theme-settings.json b/Sources/PostgresNIO/Docs.docc/theme-settings.json index dda76197..38914a04 100644 --- a/Sources/PostgresNIO/Docs.docc/theme-settings.json +++ b/Sources/PostgresNIO/Docs.docc/theme-settings.json @@ -1,16 +1,19 @@ { "theme": { - "aside": { "border-radius": "6px", "border-style": "double", "border-width": "3px" }, + "aside": { "border-radius": "16px", "border-style": "double", "border-width": "3px" }, "border-radius": "0", "button": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "code": { "border-radius": "16px", "border-width": "1px", "border-style": "solid" }, "color": { + "fill": { "dark": "#000", "light": "#fff" }, "psqlnio": "#336791", "documentation-intro-fill": "radial-gradient(circle at top, var(--color-psqlnio) 30%, #000 100%)", "documentation-intro-accent": "var(--color-psqlnio)", + "documentation-intro-eyebrow": "white", + "documentation-intro-figure": "white", + "documentation-intro-title": "white", "logo-base": { "dark": "#fff", "light": "#000" }, - "logo-shape": { "dark": "#000", "light": "#fff" }, - "fill": { "dark": "#000", "light": "#fff" } + "logo-shape": { "dark": "#000", "light": "#fff" } }, "icons": { "technology": "/postgresnio/images/vapor-postgresnio-logo.svg" } }, diff --git a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift index 786b91ef..5d111e3b 100644 --- a/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift +++ b/Sources/PostgresNIO/Message/PostgresMessage+Identifier.swift @@ -4,7 +4,7 @@ extension PostgresMessage { /// Identifies an incoming or outgoing postgres message. Sent as the first byte, before the message size. /// Values are not unique across all identifiers, meaning some messages will require keeping state to identify. @available(*, deprecated, message: "Will be removed from public API.") - public struct Identifier: ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { + public struct Identifier: Sendable, ExpressibleByIntegerLiteral, Equatable, CustomStringConvertible { // special public static let none: Identifier = 0x00 // special diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index 9d264bcc..8560b948 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -752,6 +752,12 @@ struct ConnectionStateMachine { return self.modify(with: action) } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> ConnectionAction { + return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> ConnectionAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state, !queryState.isComplete else { return self.closeConnectionAndCleanup(.unexpectedBackendMessage(.emptyQueryResponse)) diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index 78f0d202..5708b6b9 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -10,7 +10,8 @@ struct ExtendedQueryStateMachine { case parameterDescriptionReceived(ExtendedQueryContext) case rowDescriptionReceived(ExtendedQueryContext, [RowDescription.Column]) case noDataMessageReceived(ExtendedQueryContext) - + case emptyQueryResponseReceived + /// A state that is used if a noData message was received before. If a row description was received `bufferingRows` is /// used after receiving a `bindComplete` message case bindCompleteReceived(ExtendedQueryContext) @@ -90,7 +91,7 @@ struct ExtendedQueryStateMachine { mutating func cancel() -> Action { switch self.state { case .initialized: - preconditionFailure("Start must be called immediatly after the query was created") + preconditionFailure("Start must be called immediately after the query was created") case .messagesSent(let queryContext), .parseCompleteReceived(let queryContext), @@ -122,7 +123,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(.queryCancelled, read: true) } - case .commandComplete, .error, .drain: + case .commandComplete, .emptyQueryResponseReceived, .error, .drain: // the stream has already finished. return .wait @@ -229,6 +230,7 @@ struct ExtendedQueryStateMachine { .messagesSent, .parseCompleteReceived, .parameterDescriptionReceived, + .emptyQueryResponseReceived, .bindCompleteReceived, .streaming, .drain, @@ -268,6 +270,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived, .commandComplete, @@ -285,7 +288,7 @@ struct ExtendedQueryStateMachine { case .unnamed(_, let eventLoopPromise), .executeStatement(_, let eventLoopPromise): return self.avoidingStateMachineCoW { state -> Action in state = .commandComplete(commandTag: commandTag) - let result = QueryResult(value: .noRows(commandTag), logger: context.logger) + let result = QueryResult(value: .noRows(.tag(commandTag)), logger: context.logger) return .succeedQuery(eventLoopPromise, with: result) } @@ -309,6 +312,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .commandComplete, .error: @@ -318,8 +322,29 @@ struct ExtendedQueryStateMachine { } } + mutating func copyInResponseReceived( + _ copyInResponse: PostgresBackendMessage.CopyInResponse + ) -> Action { + return self.setAndFireError(.unexpectedBackendMessage(.copyInResponse(copyInResponse))) + } + mutating func emptyQueryResponseReceived() -> Action { - preconditionFailure("Unimplemented") + guard case .bindCompleteReceived(let queryContext) = self.state else { + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } + + switch queryContext.query { + case .unnamed(_, let eventLoopPromise), + .executeStatement(_, let eventLoopPromise): + return self.avoidingStateMachineCoW { state -> Action in + state = .emptyQueryResponseReceived + let result = QueryResult(value: .noRows(.emptyResponse), logger: queryContext.logger) + return .succeedQuery(eventLoopPromise, with: result) + } + + case .prepareStatement(_, _, _, _): + return self.setAndFireError(.unexpectedBackendMessage(.emptyQueryResponse)) + } } mutating func errorReceived(_ errorMessage: PostgresBackendMessage.ErrorResponse) -> Action { @@ -336,7 +361,7 @@ struct ExtendedQueryStateMachine { return self.setAndFireError(error) case .streaming, .drain: return self.setAndFireError(error) - case .commandComplete: + case .commandComplete, .emptyQueryResponseReceived: return self.setAndFireError(.unexpectedBackendMessage(.error(errorMessage))) case .error: preconditionFailure(""" @@ -382,6 +407,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: preconditionFailure("Requested to consume next row without anything going on.") @@ -405,6 +431,7 @@ struct ExtendedQueryStateMachine { .parseCompleteReceived, .parameterDescriptionReceived, .noDataMessageReceived, + .emptyQueryResponseReceived, .rowDescriptionReceived, .bindCompleteReceived: return .wait @@ -449,6 +476,7 @@ struct ExtendedQueryStateMachine { } case .initialized, .commandComplete, + .emptyQueryResponseReceived, .drain, .error: // we already have the complete stream received, now we are waiting for a @@ -495,7 +523,7 @@ struct ExtendedQueryStateMachine { return .forwardStreamError(error, read: true) } - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: preconditionFailure(""" This state must not be reached. If the query `.isComplete`, the ConnectionStateMachine must not send any further events to the substate machine. @@ -507,7 +535,7 @@ struct ExtendedQueryStateMachine { var isComplete: Bool { switch self.state { - case .commandComplete, .error: + case .commandComplete, .emptyQueryResponseReceived, .error: return true case .noDataMessageReceived(let context), .rowDescriptionReceived(let context, _): diff --git a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift index 41091ab3..7e8376a7 100644 --- a/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift +++ b/Sources/PostgresNIO/New/Data/String+PostgresCodable.swift @@ -29,6 +29,12 @@ extension String: PostgresDecodable { context: PostgresDecodingContext ) throws { switch (format, type) { + case (.binary, .jsonb): + // Discard the version byte + guard let version = buffer.readInteger(as: UInt8.self), version == 1 else { + throw PostgresDecodingError.Code.failure + } + self = buffer.readString(length: buffer.readableBytes)! case (_, .varchar), (_, .bpchar), (_, .text), @@ -36,13 +42,20 @@ extension String: PostgresDecodable { // we can force unwrap here, since this method only fails if there are not enough // bytes available. self = buffer.readString(length: buffer.readableBytes)! + case (_, .uuid): guard let uuid = try? UUID(from: &buffer, type: .uuid, format: format, context: context) else { throw PostgresDecodingError.Code.failure } self = uuid.uuidString + default: - throw PostgresDecodingError.Code.typeMismatch + // We should eagerly try to convert any datatype into a String. For example the oid + // for ltree isn't static. For this reason we should just try to convert anything. + guard let string = buffer.readString(length: buffer.readableBytes) else { + throw PostgresDecodingError.Code.typeMismatch + } + self = string } } } diff --git a/Sources/PostgresNIO/New/Messages/CopyInMessage.swift b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift new file mode 100644 index 00000000..46dec648 --- /dev/null +++ b/Sources/PostgresNIO/New/Messages/CopyInMessage.swift @@ -0,0 +1,44 @@ +extension PostgresBackendMessage { + struct CopyInResponse: Hashable { + enum Format: Int8 { + case textual = 0 + case binary = 1 + } + + let format: Format + let columnFormats: [Format] + + static func decode(from buffer: inout ByteBuffer) throws -> Self { + guard let rawFormat = buffer.readInteger(endianness: .big, as: Int8.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(1, actual: buffer.readableBytes) + } + guard let format = Format(rawValue: rawFormat) else { + throw PSQLPartialDecodingError.unexpectedValue(value: rawFormat) + } + + guard let numColumns = buffer.readInteger(endianness: .big, as: Int16.self) else { + throw PSQLPartialDecodingError.expectedAtLeastNRemainingBytes(2, actual: buffer.readableBytes) + } + var columnFormatCodes: [Format] = [] + columnFormatCodes.reserveCapacity(Int(numColumns)) + + for _ in 0.. (stream: AsyncThrowingStream, continuation: AsyncThrowingStream.Continuation) where Failure == Error { - var continuation: AsyncThrowingStream.Continuation! - let stream = AsyncThrowingStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } - } - #endif diff --git a/Sources/PostgresNIO/New/PSQLEventsHandler.swift b/Sources/PostgresNIO/New/PSQLEventsHandler.swift index 2bf0d6d8..0f426f20 100644 --- a/Sources/PostgresNIO/New/PSQLEventsHandler.swift +++ b/Sources/PostgresNIO/New/PSQLEventsHandler.swift @@ -68,10 +68,8 @@ final class PSQLEventsHandler: ChannelInboundHandler { case .authenticated: break } - case TLSUserEvent.shutdownCompleted: - break default: - preconditionFailure() + context.fireUserInboundEventTriggered(event) } } diff --git a/Sources/PostgresNIO/New/PSQLRowStream.swift b/Sources/PostgresNIO/New/PSQLRowStream.swift index b7f2d4fb..ee925d0e 100644 --- a/Sources/PostgresNIO/New/PSQLRowStream.swift +++ b/Sources/PostgresNIO/New/PSQLRowStream.swift @@ -3,7 +3,7 @@ import Logging struct QueryResult { enum Value: Equatable { - case noRows(String) + case noRows(PSQLRowStream.StatementSummary) case rowDescription([RowDescription.Column]) } @@ -16,25 +16,30 @@ struct QueryResult { final class PSQLRowStream: @unchecked Sendable { private typealias AsyncSequenceSource = NIOThrowingAsyncSequenceProducer.Source + enum StatementSummary: Equatable { + case tag(String) + case emptyResponse + } + enum Source { case stream([RowDescription.Column], PSQLRowsDataSource) - case noRows(Result) + case noRows(Result) } let eventLoop: EventLoop let logger: Logger - + private enum BufferState { case streaming(buffer: CircularBuffer, dataSource: PSQLRowsDataSource) - case finished(buffer: CircularBuffer, commandTag: String) + case finished(buffer: CircularBuffer, summary: StatementSummary) case failure(Error) } - + private enum DownstreamState { case waitingForConsumer(BufferState) case iteratingRows(onRow: (PostgresRow) throws -> (), EventLoopPromise, PSQLRowsDataSource) case waitingForAll([PostgresRow], EventLoopPromise<[PostgresRow]>, PSQLRowsDataSource) - case consumed(Result) + case consumed(Result) case asyncSequence(AsyncSequenceSource, PSQLRowsDataSource, onFinish: @Sendable () -> ()) } @@ -52,9 +57,9 @@ final class PSQLRowStream: @unchecked Sendable { case .stream(let rowDescription, let dataSource): self.rowDescription = rowDescription bufferState = .streaming(buffer: .init(), dataSource: dataSource) - case .noRows(.success(let commandTag)): + case .noRows(.success(let summary)): self.rowDescription = [] - bufferState = .finished(buffer: .init(), commandTag: commandTag) + bufferState = .finished(buffer: .init(), summary: summary) case .noRows(.failure(let error)): self.rowDescription = [] bufferState = .failure(error) @@ -98,12 +103,12 @@ final class PSQLRowStream: @unchecked Sendable { self.downstreamState = .asyncSequence(source, dataSource, onFinish: onFinish) self.executeActionBasedOnYieldResult(yieldResult, source: dataSource) - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): _ = source.yield(contentsOf: buffer) source.finish() onFinish() - self.downstreamState = .consumed(.success(commandTag)) - + self.downstreamState = .consumed(.success(summary)) + case .failure(let error): source.finish(error) self.downstreamState = .consumed(.failure(error)) @@ -190,12 +195,12 @@ final class PSQLRowStream: @unchecked Sendable { dataSource.request(for: self) return promise.futureResult - case .finished(let buffer, let commandTag): + case .finished(let buffer, let summary): let rows = buffer.map { PostgresRow(data: $0, lookupTable: self.lookupTable, columns: self.rowDescription) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededFuture(rows) case .failure(let error): @@ -247,8 +252,8 @@ final class PSQLRowStream: @unchecked Sendable { } return promise.futureResult - - case .finished(let buffer, let commandTag): + + case .finished(let buffer, let summary): do { for data in buffer { let row = PostgresRow( @@ -259,7 +264,7 @@ final class PSQLRowStream: @unchecked Sendable { try onRow(row) } - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(summary)) return self.eventLoop.makeSucceededVoidFuture() } catch { self.downstreamState = .consumed(.failure(error)) @@ -292,7 +297,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.finished), .waitingForConsumer(.failure): preconditionFailure("How can new rows be received, if an end was already signalled?") - + case .iteratingRows(let onRow, let promise, let dataSource): do { for data in newRows { @@ -347,25 +352,25 @@ final class PSQLRowStream: @unchecked Sendable { private func receiveEnd(_ commandTag: String) { switch self.downstreamState { case .waitingForConsumer(.streaming(buffer: let buffer, _)): - self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, commandTag: commandTag)) - - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + self.downstreamState = .waitingForConsumer(.finished(buffer: buffer, summary: .tag(commandTag))) + + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(()) case .waitingForAll(let rows, let promise, _): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) promise.succeed(rows) case .asyncSequence(let source, _, let onFinish): - self.downstreamState = .consumed(.success(commandTag)) + self.downstreamState = .consumed(.success(.tag(commandTag))) source.finish() onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -375,7 +380,7 @@ final class PSQLRowStream: @unchecked Sendable { case .waitingForConsumer(.streaming): self.downstreamState = .waitingForConsumer(.failure(error)) - case .waitingForConsumer(.finished), .waitingForConsumer(.failure): + case .waitingForConsumer(.finished), .waitingForConsumer(.failure), .consumed(.success(.emptyResponse)): preconditionFailure("How can we get another end, if an end was already signalled?") case .iteratingRows(_, let promise, _): @@ -391,7 +396,7 @@ final class PSQLRowStream: @unchecked Sendable { consumer.finish(error) onFinish() - case .consumed: + case .consumed(.success(.tag)), .consumed(.failure): break } } @@ -413,10 +418,15 @@ final class PSQLRowStream: @unchecked Sendable { } var commandTag: String { - guard case .consumed(.success(let commandTag)) = self.downstreamState else { + guard case .consumed(.success(let consumed)) = self.downstreamState else { preconditionFailure("commandTag may only be called if all rows have been consumed") } - return commandTag + switch consumed { + case .tag(let tag): + return tag + case .emptyResponse: + return "" + } } } diff --git a/Sources/PostgresNIO/New/PSQLTask.swift b/Sources/PostgresNIO/New/PSQLTask.swift index 363f9394..6106fd21 100644 --- a/Sources/PostgresNIO/New/PSQLTask.swift +++ b/Sources/PostgresNIO/New/PSQLTask.swift @@ -1,7 +1,7 @@ import Logging import NIOCore -enum HandlerTask { +enum HandlerTask: Sendable { case extendedQuery(ExtendedQueryContext) case closeCommand(CloseCommandContext) case startListening(NotificationListener) @@ -31,7 +31,7 @@ enum PSQLTask { } } -final class ExtendedQueryContext { +final class ExtendedQueryContext: Sendable { enum Query { case unnamed(PostgresQuery, EventLoopPromise) case executeStatement(PSQLExecuteStatement, EventLoopPromise) @@ -100,14 +100,15 @@ final class PreparedStatementContext: Sendable { } } -final class CloseCommandContext { +final class CloseCommandContext: Sendable { let target: CloseTarget let logger: Logger let promise: EventLoopPromise - init(target: CloseTarget, - logger: Logger, - promise: EventLoopPromise + init( + target: CloseTarget, + logger: Logger, + promise: EventLoopPromise ) { self.target = target self.logger = logger diff --git a/Sources/PostgresNIO/New/PostgresBackendMessage.swift b/Sources/PostgresNIO/New/PostgresBackendMessage.swift index 792beec3..030d651d 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessage.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessage.swift @@ -29,6 +29,7 @@ enum PostgresBackendMessage: Hashable { case bindComplete case closeComplete case commandComplete(String) + case copyInResponse(CopyInResponse) case dataRow(DataRow) case emptyQueryResponse case error(ErrorResponse) @@ -96,6 +97,9 @@ extension PostgresBackendMessage { } return .commandComplete(commandTag) + case .copyInResponse: + return try .copyInResponse(.decode(from: &buffer)) + case .dataRow: return try .dataRow(.decode(from: &buffer)) @@ -131,9 +135,9 @@ extension PostgresBackendMessage { case .rowDescription: return try .rowDescription(.decode(from: &buffer)) - - case .copyData, .copyDone, .copyInResponse, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion: - preconditionFailure() + + case .copyData, .copyDone, .copyOutResponse, .copyBothResponse, .functionCallResponse, .negotiateProtocolVersion: + throw PSQLPartialDecodingError.unknownMessageKind(messageID) } } } @@ -151,6 +155,8 @@ extension PostgresBackendMessage: CustomDebugStringConvertible { return ".closeComplete" case .commandComplete(let commandTag): return ".commandComplete(\(String(reflecting: commandTag)))" + case .copyInResponse(let copyInResponse): + return ".copyInResponse(\(String(reflecting: copyInResponse)))" case .dataRow(let dataRow): return ".dataRow(\(String(reflecting: dataRow)))" case .emptyQueryResponse: diff --git a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift index 6f6be7ec..155c6714 100644 --- a/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift +++ b/Sources/PostgresNIO/New/PostgresBackendMessageDecoder.swift @@ -189,6 +189,12 @@ struct PSQLPartialDecodingError: Error { description: "Expected the integer to be positive or null, but got \(actual).", file: file, line: line) } + + static func unknownMessageKind(_ messageID: PostgresBackendMessage.ID, file: String = #fileID, line: Int = #line) -> Self { + return PSQLPartialDecodingError( + description: "Unknown message kind: \(messageID)", + file: file, line: line) + } } extension ByteBuffer { diff --git a/Sources/PostgresNIO/New/PostgresChannelHandler.swift b/Sources/PostgresNIO/New/PostgresChannelHandler.swift index 32dea4a5..baf801e5 100644 --- a/Sources/PostgresNIO/New/PostgresChannelHandler.swift +++ b/Sources/PostgresNIO/New/PostgresChannelHandler.swift @@ -20,7 +20,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private var decoder: NIOSingleStepByteToMessageProcessor private var encoder: PostgresFrontendMessageEncoder! private let configuration: PostgresConnection.InternalConfiguration - private let configureSSLCallback: ((Channel) throws -> Void)? + private let configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? private var listenState = ListenStateMachine() private var preparedStatementState = PreparedStatementStateMachine() @@ -29,7 +29,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { configuration: PostgresConnection.InternalConfiguration, eventLoop: EventLoop, logger: Logger, - configureSSLCallback: ((Channel) throws -> Void)? + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? ) { self.state = ConnectionStateMachine(requireBackendKeyData: configuration.options.requireBackendKeyData) self.eventLoop = eventLoop @@ -46,7 +46,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { eventLoop: EventLoop, state: ConnectionStateMachine = .init(.initialized), logger: Logger = .psqlNoOpLogger, - configureSSLCallback: ((Channel) throws -> Void)? + configureSSLCallback: ((Channel, PostgresChannelHandler) throws -> Void)? ) { self.state = state self.eventLoop = eventLoop @@ -136,6 +136,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { action = self.state.closeCompletedReceived() case .commandComplete(let commandTag): action = self.state.commandCompletedReceived(commandTag) + case .copyInResponse(let copyInResponse): + action = self.state.copyInResponseReceived(copyInResponse) case .dataRow(let dataRow): action = self.state.dataRowReceived(dataRow) case .emptyQueryResponse: @@ -390,7 +392,8 @@ final class PostgresChannelHandler: ChannelDuplexHandler { let authContext = AuthContext( username: username, password: self.configuration.password, - database: self.configuration.database + database: self.configuration.database, + additionalParameters: self.configuration.options.additionalStartupParameters ) let action = self.state.provideAuthenticationContext(authContext) return self.run(action, with: context) @@ -438,7 +441,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { // This method must only be called, if we signalized the StateMachine before that we are // able to setup a SSL connection. do { - try self.configureSSLCallback!(context.channel) + try self.configureSSLCallback!(context.channel, self) let action = self.state.sslHandlerAdded() self.run(action, with: context) } catch { @@ -549,9 +552,9 @@ final class PostgresChannelHandler: ChannelDuplexHandler { ) self.rowStream = rows - case .noRows(let commandTag): + case .noRows(let summary): rows = PSQLRowStream( - source: .noRows(.success(commandTag)), + source: .noRows(.success(summary)), eventLoop: context.channel.eventLoop, logger: result.logger ) @@ -594,7 +597,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makeStartListeningQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( - query: PostgresQuery(unsafeSQL: "LISTEN \(channel);"), + query: PostgresQuery(unsafeSQL: #"LISTEN "\#(channel)";"#), logger: self.logger, promise: promise ) @@ -642,7 +645,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler { private func makeUnlistenQuery(channel: String, context: ChannelHandlerContext) -> PSQLTask { let promise = context.eventLoop.makePromise(of: PSQLRowStream.self) let query = ExtendedQueryContext( - query: PostgresQuery(unsafeSQL: "UNLISTEN \(channel);"), + query: PostgresQuery(unsafeSQL: #"UNLISTEN "\#(channel)";"#), logger: self.logger, promise: promise ) diff --git a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift index 97805418..6ca4cc27 100644 --- a/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift +++ b/Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift @@ -167,6 +167,28 @@ struct PostgresFrontendMessageEncoder { self.buffer.writeMultipleIntegers(UInt32(8), Self.sslRequestCode) } + /// Adds the `CopyData` message ID and `dataLength` to the message buffer but not the actual data. + /// + /// The caller of this function is expected to write the encoder's message buffer to the backend after calling this + /// function, followed by sending the actual data to the backend. + mutating func copyDataHeader(dataLength: UInt32) { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyData, length: dataLength) + } + + mutating func copyDone() { + self.clearIfNeeded() + self.buffer.psqlWriteMultipleIntegers(id: .copyDone, length: 0) + } + + mutating func copyFail(message: String) { + self.clearIfNeeded() + var messageBuffer = ByteBuffer() + messageBuffer.writeNullTerminatedString(message) + self.buffer.psqlWriteMultipleIntegers(id: .copyFail, length: UInt32(messageBuffer.readableBytes)) + self.buffer.writeImmutableBuffer(messageBuffer) + } + mutating func sync() { self.clearIfNeeded() self.buffer.psqlWriteMultipleIntegers(id: .sync, length: 0) @@ -197,6 +219,9 @@ struct PostgresFrontendMessageEncoder { private enum FrontendMessageID: UInt8, Hashable, Sendable { case bind = 66 // B case close = 67 // C + case copyData = 100 // d + case copyDone = 99 // c + case copyFail = 102 // f case describe = 68 // D case execute = 69 // E case flush = 72 // H diff --git a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift index 55fb0670..d8f525eb 100644 --- a/Sources/PostgresNIO/New/PostgresNotificationSequence.swift +++ b/Sources/PostgresNIO/New/PostgresNotificationSequence.swift @@ -3,7 +3,7 @@ public struct PostgresNotification: Sendable { public let payload: String } -public struct PostgresNotificationSequence: AsyncSequence { +public struct PostgresNotificationSequence: AsyncSequence, Sendable { public typealias Element = PostgresNotification let base: AsyncThrowingStream @@ -21,7 +21,5 @@ public struct PostgresNotificationSequence: AsyncSequence { } } -#if swift(>=5.7) -// AsyncThrowingStream is marked as Sendable in Swift 5.6 -extension PostgresNotificationSequence: Sendable {} -#endif +@available(*, unavailable) +extension PostgresNotificationSequence.AsyncIterator: Sendable {} diff --git a/Sources/PostgresNIO/New/PostgresQuery.swift b/Sources/PostgresNIO/New/PostgresQuery.swift index 1cfcf2dc..6449ab29 100644 --- a/Sources/PostgresNIO/New/PostgresQuery.swift +++ b/Sources/PostgresNIO/New/PostgresQuery.swift @@ -26,7 +26,7 @@ extension PostgresQuery: ExpressibleByStringInterpolation { } extension PostgresQuery { - public struct StringInterpolation: StringInterpolationProtocol { + public struct StringInterpolation: StringInterpolationProtocol, Sendable { public typealias StringLiteralType = String @usableFromInline @@ -172,6 +172,16 @@ public struct PostgresBindings: Sendable, Hashable { try self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -181,11 +191,34 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) throws { + switch value { + case .none: + self.appendNull() + case let .some(value): + try self.append(value, context: context) + } + } + @inlinable public mutating func append(_ value: Value) { self.append(value, context: .default) } + @inlinable + public mutating func append(_ value: Optional) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value) + } + } + @inlinable public mutating func append( _ value: Value, @@ -195,6 +228,19 @@ public struct PostgresBindings: Sendable, Hashable { self.metadata.append(.init(value: value, protected: true)) } + @inlinable + public mutating func append( + _ value: Optional, + context: PostgresEncodingContext + ) { + switch value { + case .none: + self.appendNull() + case let .some(value): + self.append(value, context: context) + } + } + @inlinable mutating func appendUnprotected( _ value: Value, diff --git a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift deleted file mode 100644 index 71aa04dc..00000000 --- a/Sources/PostgresNIO/New/PostgresRow-multi-decode.swift +++ /dev/null @@ -1,1175 +0,0 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrow-multi-decode.sh - -#if compiler(<5.9) -extension PostgresRow { - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0) { - precondition(self.columns.count >= 1) - let columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - let column = columnIterator.next().unsafelyUnwrapped - let swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) throws -> (T0) { - try self.decode(T0.self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { - precondition(self.columns.count >= 2) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1) { - try self.decode((T0, T1).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { - precondition(self.columns.count >= 3) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2) { - try self.decode((T0, T1, T2).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { - precondition(self.columns.count >= 4) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3) { - try self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { - precondition(self.columns.count >= 5) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4) { - try self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { - precondition(self.columns.count >= 6) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5) { - try self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { - precondition(self.columns.count >= 7) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6) { - try self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { - precondition(self.columns.count >= 8) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { - precondition(self.columns.count >= 9) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { - precondition(self.columns.count >= 10) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { - precondition(self.columns.count >= 11) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { - precondition(self.columns.count >= 12) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { - precondition(self.columns.count >= 13) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { - precondition(self.columns.count >= 14) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 13 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T13.self - let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { - precondition(self.columns.count >= 15) - var columnIndex = 0 - var cellIterator = self.data.makeIterator() - var cellData = cellIterator.next().unsafelyUnwrapped - var columnIterator = self.columns.makeIterator() - var column = columnIterator.next().unsafelyUnwrapped - var swiftTargetType: Any.Type = T0.self - - do { - let r0 = try T0._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 1 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T1.self - let r1 = try T1._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 2 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T2.self - let r2 = try T2._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 3 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T3.self - let r3 = try T3._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 4 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T4.self - let r4 = try T4._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 5 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T5.self - let r5 = try T5._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 6 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T6.self - let r6 = try T6._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 7 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T7.self - let r7 = try T7._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 8 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T8.self - let r8 = try T8._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 9 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T9.self - let r9 = try T9._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 10 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T10.self - let r10 = try T10._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 11 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T11.self - let r11 = try T11._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 12 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T12.self - let r12 = try T12._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 13 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T13.self - let r13 = try T13._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - columnIndex = 14 - cellData = cellIterator.next().unsafelyUnwrapped - column = columnIterator.next().unsafelyUnwrapped - swiftTargetType = T14.self - let r14 = try T14._decodeRaw(from: &cellData, type: column.dataType, format: column.format, context: context) - - return (r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13, r14) - } catch let code as PostgresDecodingError.Code { - throw PostgresDecodingError( - code: code, - columnName: column.name, - columnIndex: columnIndex, - targetType: swiftTargetType, - postgresType: column.dataType, - postgresFormat: column.format, - postgresData: cellData, - file: file, - line: line - ) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) throws -> (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14) { - try self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) - } -} -#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift b/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift deleted file mode 100644 index f45357d8..00000000 --- a/Sources/PostgresNIO/New/PostgresRowSequence-multi-decode.swift +++ /dev/null @@ -1,215 +0,0 @@ -/// NOTE: THIS FILE IS AUTO-GENERATED BY dev/generate-postgresrowsequence-multi-decode.sh - -#if compiler(<5.9) -extension AsyncSequence where Element == PostgresRow { - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode(T0.self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode(T0.self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13).self, context: .default, file: file, line: line) - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, context: PostgresDecodingContext, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.map { row in - try row.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: context, file: file, line: line) - } - } - - @inlinable - @_alwaysEmitIntoClient - public func decode(_: (T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).Type, file: String = #fileID, line: Int = #line) -> AsyncThrowingMapSequence { - self.decode((T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14).self, context: .default, file: file, line: line) - } -} -#endif diff --git a/Sources/PostgresNIO/New/PostgresRowSequence.swift b/Sources/PostgresNIO/New/PostgresRowSequence.swift index ccf4f69c..3936b51e 100644 --- a/Sources/PostgresNIO/New/PostgresRowSequence.swift +++ b/Sources/PostgresNIO/New/PostgresRowSequence.swift @@ -4,7 +4,7 @@ import NIOConcurrencyHelpers /// An async sequence of ``PostgresRow``s. /// /// - Note: This is a struct to allow us to move to a move only type easily once they become available. -public struct PostgresRowSequence: AsyncSequence { +public struct PostgresRowSequence: AsyncSequence, Sendable { public typealias Element = PostgresRow typealias BackingSequence = NIOThrowingAsyncSequenceProducer @@ -56,6 +56,9 @@ extension PostgresRowSequence { } } +@available(*, unavailable) +extension PostgresRowSequence.AsyncIterator: Sendable {} + extension PostgresRowSequence { public func collect() async throws -> [PostgresRow] { var result = [PostgresRow]() diff --git a/Sources/PostgresNIO/New/PostgresTransactionError.swift b/Sources/PostgresNIO/New/PostgresTransactionError.swift new file mode 100644 index 00000000..35038446 --- /dev/null +++ b/Sources/PostgresNIO/New/PostgresTransactionError.swift @@ -0,0 +1,21 @@ +/// A wrapper around the errors that can occur during a transaction. +public struct PostgresTransactionError: Error { + + /// The file in which the transaction was started + public var file: String + /// The line in which the transaction was started + public var line: Int + + /// The error thrown when running the `BEGIN` query + public var beginError: Error? + /// The error thrown in the transaction closure + public var closureError: Error? + + /// The error thrown while rolling the transaction back. If the ``closureError`` is set, + /// but the ``rollbackError`` is empty, the rollback was successful. If the ``rollbackError`` + /// is set, the rollback failed. + public var rollbackError: Error? + + /// The error thrown while commiting the transaction. + public var commitError: Error? +} diff --git a/Sources/PostgresNIO/New/VariadicGenerics.swift b/Sources/PostgresNIO/New/VariadicGenerics.swift index 312d36dc..7931c90c 100644 --- a/Sources/PostgresNIO/New/VariadicGenerics.swift +++ b/Sources/PostgresNIO/New/VariadicGenerics.swift @@ -1,4 +1,4 @@ -#if compiler(>=5.9) + extension PostgresRow { // --- snip TODO: Remove once bug is fixed, that disallows tuples of one @inlinable @@ -170,5 +170,3 @@ enum ComputeParameterPackLength { MemoryLayout<(repeat BoolConverter.Bool)>.size / MemoryLayout.stride } } -#endif // compiler(>=5.9) - diff --git a/Sources/PostgresNIO/Pool/ConnectionFactory.swift b/Sources/PostgresNIO/Pool/ConnectionFactory.swift index 77a0c047..31343826 100644 --- a/Sources/PostgresNIO/Pool/ConnectionFactory.swift +++ b/Sources/PostgresNIO/Pool/ConnectionFactory.swift @@ -15,9 +15,9 @@ final class ConnectionFactory: Sendable { struct SSLContextCache: Sendable { enum State { case none - case producing(TLSConfiguration, [CheckedContinuation]) - case cached(TLSConfiguration, NIOSSLContext) - case failed(TLSConfiguration, any Error) + case producing([CheckedContinuation]) + case cached(NIOSSLContext) + case failed(any Error) } var state: State = .none @@ -89,6 +89,7 @@ final class ConnectionFactory: Sendable { connectionConfig.options.connectTimeout = TimeAmount(config.options.connectTimeout) connectionConfig.options.tlsServerName = config.options.tlsServerName connectionConfig.options.requireBackendKeyData = config.options.requireBackendKeyData + connectionConfig.options.additionalStartupParameters = config.options.additionalStartupParameters return connectionConfig } @@ -105,34 +106,17 @@ final class ConnectionFactory: Sendable { let action = self.sslContextBox.withLockedValue { cache -> Action in switch cache.state { case .none: - cache.state = .producing(tlsConfiguration, [continuation]) + cache.state = .producing([continuation]) return .produce - case .cached(let cachedTLSConfiguration, let context): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - return .succeed(context) - } else { - cache.state = .producing(tlsConfiguration, [continuation]) - return .produce - } - - case .failed(let cachedTLSConfiguration, let error): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - return .fail(error) - } else { - cache.state = .producing(tlsConfiguration, [continuation]) - return .produce - } - - case .producing(let cachedTLSConfiguration, var continuations): + case .cached(let context): + return .succeed(context) + case .failed(let error): + return .fail(error) + case .producing(var continuations): continuations.append(continuation) - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - cache.state = .producing(cachedTLSConfiguration, continuations) - return .wait - } else { - cache.state = .producing(tlsConfiguration, continuations) - return .produce - } + cache.state = .producing(continuations) + return .wait } } @@ -142,10 +126,7 @@ final class ConnectionFactory: Sendable { case .produce: // TBD: we might want to consider moving this off the concurrent executor - self.reportProduceSSLContextResult( - Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)}), - for: tlsConfiguration - ) + self.reportProduceSSLContextResult(Result(catching: {try NIOSSLContext(configuration: tlsConfiguration)})) case .succeed(let context): continuation.resume(returning: context) @@ -156,7 +137,7 @@ final class ConnectionFactory: Sendable { } } - private func reportProduceSSLContextResult(_ result: Result, for tlsConfiguration: TLSConfiguration) { + private func reportProduceSSLContextResult(_ result: Result) { enum Action { case fail(any Error, [CheckedContinuation]) case succeed(NIOSSLContext, [CheckedContinuation]) @@ -171,19 +152,15 @@ final class ConnectionFactory: Sendable { case .cached, .failed: return .none - case .producing(let cachedTLSConfiguration, let continuations): - if cachedTLSConfiguration.bestEffortEquals(tlsConfiguration) { - switch result { - case .success(let context): - cache.state = .cached(cachedTLSConfiguration, context) - return .succeed(context, continuations) - - case .failure(let failure): - cache.state = .failed(cachedTLSConfiguration, failure) - return .fail(failure, continuations) - } - } else { - return .none + case .producing(let continuations): + switch result { + case .success(let context): + cache.state = .cached(context) + return .succeed(context, continuations) + + case .failure(let failure): + cache.state = .failed(failure) + return .fail(failure, continuations) } } } diff --git a/Sources/PostgresNIO/Pool/PostgresClient.swift b/Sources/PostgresNIO/Pool/PostgresClient.swift index 9383ffcd..0279be07 100644 --- a/Sources/PostgresNIO/Pool/PostgresClient.swift +++ b/Sources/PostgresNIO/Pool/PostgresClient.swift @@ -47,7 +47,7 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { } /// Do not try to create a TLS connection to the server. - public static var disable: Self = Self.init(.disable) + public static let disable: Self = Self.init(.disable) /// Try to create a TLS connection to the server. If the server supports TLS, create a TLS connection. /// If the server does not support TLS, create an insecure connection. @@ -106,6 +106,10 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { /// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default). public var requireBackendKeyData: Bool = true + /// Additional parameters to send to the server on startup. The name value pairs are added to the initial + /// startup message that the client sends to the server. + public var additionalStartupParameters: [(String, String)] = [] + /// The minimum number of connections that the client shall keep open at any time, even if there is no /// demand. Default to `0`. /// @@ -289,20 +293,94 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { return ConnectionAndMetadata(connection: connection, maximalStreamsOnConnection: 1) } } - /// Lease a connection for the provided `closure`'s lifetime. /// /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture /// the provided `PostgresConnection`. /// - Returns: The closure's return value. + @_disfavoredOverload public func withConnection(_ closure: (PostgresConnection) async throws -> Result) async throws -> Result { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() - defer { self.pool.releaseConnection(connection) } + defer { lease.release() } + + return try await closure(lease.connection) + } + + #if compiler(>=6.0) + /// Lease a connection for the provided `closure`'s lifetime. + /// + /// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture + /// the provided `PostgresConnection`. + /// - Returns: The closure's return value. + public func withConnection( + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + let lease = try await self.leaseConnection() + + defer { lease.release() } + + return try await closure(lease.connection) + } - return try await closure(connection) + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ closure: (PostgresConnection) async throws -> sending Result) async throws -> sending Result { + try await self.withConnection { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, closure) + } + } + #else + + /// Lease a connection, which is in an open transaction state, for the provided `closure`'s lifetime. + /// + /// The function leases a connection from the underlying connection pool and starts a transaction by running a `BEGIN` + /// query on the leased connection against the database. It then lends the connection to the user provided closure. + /// The user can then modify the database as they wish. If the user provided closure returns successfully, the function + /// will attempt to commit the changes by running a `COMMIT` query against the database. If the user provided closure + /// throws an error, the function will attempt to rollback the changes made within the closure. + /// + /// - Parameters: + /// - logger: The `Logger` to log into for the transaction. + /// - file: The file, the transaction was started in. Used for better error reporting. + /// - line: The line, the transaction was started in. Used for better error reporting. + /// - closure: The user provided code to modify the database. Use the provided connection to run queries. + /// The connection must stay in the transaction mode. Otherwise this method will throw! + /// - Returns: The closure's return value. + public func withTransaction( + logger: Logger, + file: String = #file, + line: Int = #line, + _ closure: (PostgresConnection) async throws -> Result + ) async throws -> Result { + try await self.withConnection { connection in + try await connection.withTransaction(logger: logger, file: file, line: line, closure) + } } + #endif /// Run a query on the Postgres server the client is connected to. /// @@ -326,7 +404,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { throw PSQLError(code: .tooManyParameters, query: query, file: file, line: line) } - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection var logger = logger logger[postgresMetadataKey: .connectionID] = "\(connection.id)" @@ -341,12 +420,12 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(HandlerTask.extendedQuery(context), promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult.map { $0.asyncSequence(onFinish: { - self.pool.releaseConnection(connection) + lease.release() }) }.get() } catch var error as PSQLError { @@ -368,7 +447,8 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let logger = logger ?? Self.loggingDisabled do { - let connection = try await self.leaseConnection() + let lease = try await self.leaseConnection() + let connection = lease.connection let promise = connection.channel.eventLoop.makePromise(of: PSQLRowStream.self) let task = HandlerTask.executePreparedStatement(.init( @@ -382,11 +462,11 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { connection.channel.write(task, promise: nil) promise.futureResult.whenFailure { _ in - self.pool.releaseConnection(connection) + lease.release() } return try await promise.futureResult - .map { $0.asyncSequence(onFinish: { self.pool.releaseConnection(connection) }) } + .map { $0.asyncSequence(onFinish: { lease.release() }) } .get() .map { try preparedStatement.decodeRow($0) } } catch var error as PSQLError { @@ -419,14 +499,14 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service { let atomicOp = self.runningAtomic.compareExchange(expected: false, desired: true, ordering: .relaxed) precondition(!atomicOp.original, "PostgresClient.run() should just be called once!") - await cancelOnGracefulShutdown { + await cancelWhenGracefulShutdown { await self.pool.run() } } // MARK: - Private Methods - - private func leaseConnection() async throws -> PostgresConnection { + private func leaseConnection() async throws -> ConnectionLease { if !self.runningAtomic.load(ordering: .relaxed) { self.backgroundLogger.warning("Trying to lease connection from `PostgresClient`, but `PostgresClient.run()` hasn't been called yet.") } @@ -478,7 +558,7 @@ extension PostgresConnection: PooledConnection { self.channel.close(mode: .all, promise: nil) } - public func onClose(_ closure: @escaping ((any Error)?) -> ()) { + public func onClose(_ closure: @escaping @Sendable ((any Error)?) -> ()) { self.closeFuture.whenComplete { _ in closure(nil) } } } diff --git a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift index aa8215db..62fa326a 100644 --- a/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift +++ b/Sources/PostgresNIO/Pool/PostgresClientMetrics.swift @@ -19,7 +19,7 @@ final class PostgresClientMetrics: ConnectionPoolObservabilityDelegate { /// A connection attempt failed with the given error. After some period of /// time ``startedConnecting(id:)`` may be called again. func connectFailed(id: ConnectionID, error: Error) { - self.logger.debug("Connection creation failed", metadata: [ + self.logger.info("Connection creation failed", metadata: [ .connectionID: "\(id)", .error: "\(String(reflecting: error))" ]) diff --git a/Sources/PostgresNIO/PostgresDatabase+Query.swift b/Sources/PostgresNIO/PostgresDatabase+Query.swift index 01a7e61f..8de93814 100644 --- a/Sources/PostgresNIO/PostgresDatabase+Query.swift +++ b/Sources/PostgresNIO/PostgresDatabase+Query.swift @@ -40,7 +40,7 @@ extension PostgresDatabase { } } -public struct PostgresQueryResult { +public struct PostgresQueryResult: Sendable { public let metadata: PostgresQueryMetadata public let rows: [PostgresRow] } @@ -73,10 +73,7 @@ public struct PostgresQueryMetadata: Sendable { init?(string: String) { let parts = string.split(separator: " ") - guard parts.count >= 1 else { - return nil - } - switch parts[0] { + switch parts.first { case "INSERT": // INSERT oid rows guard parts.count == 3 else { diff --git a/Sources/PostgresNIO/Utilities/Exports.swift b/Sources/PostgresNIO/Utilities/Exports.swift index 58e12891..144ff3c9 100644 --- a/Sources/PostgresNIO/Utilities/Exports.swift +++ b/Sources/PostgresNIO/Utilities/Exports.swift @@ -1,14 +1,3 @@ -#if swift(>=5.8) - @_documentation(visibility: internal) @_exported import NIO @_documentation(visibility: internal) @_exported import NIOSSL @_documentation(visibility: internal) @_exported import struct Logging.Logger - -#else - -// TODO: Remove this with the next major release! -@_exported import NIO -@_exported import NIOSSL -@_exported import struct Logging.Logger - -#endif diff --git a/Sources/PostgresNIO/Utilities/PostgresError+Code.swift b/Sources/PostgresNIO/Utilities/PostgresError+Code.swift index 11224f4b..fae903fe 100644 --- a/Sources/PostgresNIO/Utilities/PostgresError+Code.swift +++ b/Sources/PostgresNIO/Utilities/PostgresError+Code.swift @@ -1,5 +1,5 @@ extension PostgresError { - public struct Code: ExpressibleByStringLiteral, Equatable { + public struct Code: Sendable, ExpressibleByStringLiteral, Equatable { // Class 00 — Successful Completion public static let successfulCompletion: Code = "00000" diff --git a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift index 2a717b6b..53e9c3f7 100644 --- a/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift +++ b/Sources/PostgresNIO/Utilities/SASLAuthentication+SCRAM-SHA256.swift @@ -1,4 +1,5 @@ import Crypto +import _CryptoExtras import Foundation extension UInt8 { @@ -292,7 +293,7 @@ internal struct SHA256: SASLAuthenticationMechanism { /// authenticating user. If the closure throws, authentication /// immediately fails with the thrown error. internal init(username: String, password: @escaping () throws -> String) { - self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().data(using: .utf8)!), []) }, bindingInfo: .unsupported) + self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().utf8), []) }, bindingInfo: .unsupported) } /// Set up a server-side `SCRAM-SHA-256` authentication. @@ -338,7 +339,7 @@ internal struct SHA256_PLUS: SASLAuthenticationMechanism { /// - channelBindingData: The appropriate data associated with the RFC5056 /// channel binding specified. internal init(username: String, password: @escaping () throws -> String, channelBindingName: String, channelBindingData: [UInt8]) { - self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().data(using: .utf8)!), []) }, bindingInfo: .bind(channelBindingName, channelBindingData)) + self._impl = .init(username: username, passwordGrabber: { _ in try (Array(password().utf8), []) }, bindingInfo: .bind(channelBindingName, channelBindingData)) } /// Set up a server-side `SCRAM-SHA-256` authentication. @@ -466,8 +467,8 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { // TODO: Perform `Normalize(password)`, aka the SASLprep profile (RFC4013) of stringprep (RFC3454) // Calculate `AuthMessage`, `ClientSignature`, and `ClientProof` - let saltedPassword = Hi(string: password, salt: serverSalt, iterations: serverIterations) - let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let saltedPassword = try Hi(string: password, salt: serverSalt, iterations: serverIterations) + let clientKey = HMAC.authenticationCode(for: Data("Client Key".utf8), using: saltedPassword) let storedKey = SHA256.hash(data: Data(clientKey)) var authMessage = firstMessageBare; authMessage.append(.comma); authMessage.append(contentsOf: message); authMessage.append(.comma); authMessage.append(contentsOf: clientFinalNoProof) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) @@ -485,9 +486,11 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { var clientFinalMessage = clientFinalNoProof; clientFinalMessage.append(.comma) guard let proofPart = SCRAMMessageParser.serialize([.p(Array(clientProof))]) else { throw SASLAuthenticationError.genericAuthenticationFailure } clientFinalMessage.append(contentsOf: proofPart) - + + let saltedPasswordBytes = saltedPassword.withUnsafeBytes { [UInt8]($0) } + // Save state and send - self.state = .clientSentFinalMessage(saltedPassword: saltedPassword, authMessage: authMessage) + self.state = .clientSentFinalMessage(saltedPassword: saltedPasswordBytes, authMessage: authMessage) return .continue(response: clientFinalMessage) } @@ -501,7 +504,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { switch incomingAttributes.first { case .v(let verifier): // Verify server signature - let serverKey = HMAC.authenticationCode(for: "Server Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let serverKey = HMAC.authenticationCode(for: Data("Server Key".utf8), using: .init(data: saltedPassword)) let serverSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: serverKey)) guard Array(serverSignature) == verifier else { @@ -585,7 +588,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { guard nonce == repeatNonce else { throw SASLAuthenticationError.genericAuthenticationFailure } // Compute client signature - let clientKey = HMAC.authenticationCode(for: "Client Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let clientKey = HMAC.authenticationCode(for: Data("Client Key".utf8), using: .init(data: saltedPassword)) let storedKey = SHA256.hash(data: Data(clientKey)) var authMessage = clientBareFirstMessage; authMessage.append(.comma); authMessage.append(contentsOf: serverFirstMessage); authMessage.append(.comma); authMessage.append(contentsOf: message.dropLast(proof.count + 3)) let clientSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: storedKey)) @@ -604,7 +607,7 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { guard storedKey == restoredKey else { throw SCRAMServerError.invalidProof } // Compute server signature - let serverKey = HMAC.authenticationCode(for: "Server Key".data(using: .utf8)!, using: .init(data: saltedPassword)) + let serverKey = HMAC.authenticationCode(for: Data("Server Key".utf8), using: .init(data: saltedPassword)) let serverSignature = HMAC.authenticationCode(for: authMessage, using: .init(data: serverKey)) // Generate a `server-final-message` @@ -640,19 +643,12 @@ fileprivate final class SASLMechanism_SCRAM_SHA256_Common { HMAC() == output length of H(). ```` */ -private func Hi(string: [UInt8], salt: [UInt8], iterations: UInt32) -> [UInt8] { - let key = SymmetricKey(data: string) - var Ui = HMAC.authenticationCode(for: salt + [0x00, 0x00, 0x00, 0x01], using: key) // salt + 0x00000001 as big-endian - var Hi = Array(Ui) - - Hi.withUnsafeMutableBytes { Hibuf -> Void in - for _ in 2...iterations { - Ui = HMAC.authenticationCode(for: Data(Ui), using: key) - - Ui.withUnsafeBytes { Uibuf -> Void in - for i in 0.. SymmetricKey { + try KDF.Insecure.PBKDF2.deriveKey( + from: string, + salt: salt, + using: .sha256, + outputByteCount: 32, + unsafeUncheckedRounds: Int(iterations) + ) } diff --git a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift index 3e3c9d65..c1ba89cb 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionPoolTests.swift @@ -1,7 +1,8 @@ @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils import Atomics -import XCTest import NIOEmbedded +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) final class ConnectionPoolTests: XCTestCase { @@ -26,7 +27,7 @@ final class ConnectionPoolTests: XCTestCase { // the same connection is reused 1000 times await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } @@ -38,15 +39,13 @@ final class ConnectionPoolTests: XCTestCase { do { for _ in 0..<1000 { async let connectionFuture = try await pool.leaseConnection() - var leasedConnection: MockConnection? + var connectionLease: ConnectionLease? XCTAssertEqual(factory.pendingConnectionAttemptsCount, 0) - leasedConnection = try await connectionFuture - XCTAssertNotNil(leasedConnection) - XCTAssert(createdConnection === leasedConnection) + connectionLease = try await connectionFuture + XCTAssertNotNil(connectionLease) + XCTAssert(createdConnection === connectionLease?.connection) - if let leasedConnection { - pool.releaseConnection(leasedConnection) - } + connectionLease?.release() } } catch { XCTFail("Unexpected error: \(error)") @@ -82,14 +81,14 @@ final class ConnectionPoolTests: XCTestCase { } await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } let (blockCancelStream, blockCancelContinuation) = AsyncStream.makeStream(of: Void.self) let (blockConnCreationStream, blockConnCreationContinuation) = AsyncStream.makeStream(of: Void.self) - taskGroup.addTask { + taskGroup.addTask_ { _ = try? await factory.nextConnectAttempt { _ in blockCancelContinuation.yield() var iterator = blockConnCreationStream.makeAsyncIterator() @@ -127,7 +126,7 @@ final class ConnectionPoolTests: XCTestCase { } await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() } @@ -170,12 +169,12 @@ final class ConnectionPoolTests: XCTestCase { // the same connection is reused 1000 times await withTaskGroup(of: Void.self) { taskGroup in - taskGroup.addTask { + taskGroup.addTask_ { await pool.run() XCTAssertFalse(hasFinished.compareExchange(expected: false, desired: true, ordering: .relaxed).original) } - taskGroup.addTask { + taskGroup.addTask_ { var usedConnectionIDs = Set() for _ in 0..]() for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that we got 4 distinct connections - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 4) + XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 4) // release all 4 leased connections - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } // shutdown @@ -726,7 +725,7 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests let requests = (0..<10).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLeases = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 10 @@ -734,15 +733,15 @@ final class ConnectionPoolTests: XCTestCase { for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLeases.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + XCTAssertEqual(Set(connectionLeases.lazy.map(\.connection.id)).count, 1) // release all 10 leased streams - for connection in connections { - pool.releaseConnection(connection) + for lease in connectionLeases { + lease.release() } for _ in 0..<9 { @@ -791,41 +790,41 @@ final class ConnectionPoolTests: XCTestCase { // create 4 connection requests var requests = (0..<21).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) - var connections = [MockConnection]() + var connectionLease = [ConnectionLease]() await factory.nextConnectAttempt { connectionID in return 1 } - let connection = try await requests.first!.future.success - connections.append(connection) + let lease = try await requests.first!.future.success + connectionLease.append(lease) requests.removeFirst() - pool.connectionReceivedNewMaxStreamSetting(connection, newMaxStreamSetting: 21) + pool.connectionReceivedNewMaxStreamSetting(lease.connection, newMaxStreamSetting: 21) for (_, request) in requests.enumerated() { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // Ensure that all requests got the same connection - XCTAssertEqual(Set(connections.lazy.map(\.id)).count, 1) + XCTAssertEqual(Set(connectionLease.lazy.map(\.connection.id)).count, 1) requests = (22..<42).map { ConnectionFuture(id: $0) } pool.leaseConnections(requests) // release all 21 leased streams in a single call - pool.releaseConnection(connection, streams: 21) + pool.releaseConnection(lease.connection, streams: 21) // ensure all 20 new requests got fulfilled for request in requests { let connection = try await request.future.success - connections.append(connection) + connectionLease.append(connection) } // release all 20 leased streams one by one for _ in requests { - pool.releaseConnection(connection, streams: 1) + pool.releaseConnection(lease.connection, streams: 1) } // shutdown @@ -839,14 +838,14 @@ final class ConnectionPoolTests: XCTestCase { struct ConnectionFuture: ConnectionRequestProtocol { let id: Int - let future: Future + let future: Future> init(id: Int) { self.id = id - self.future = Future(of: MockConnection.self) + self.future = Future(of: ConnectionLease.self) } - func complete(with result: Result) { + func complete(with result: Result, ConnectionPoolError>) { switch result { case .success(let success): self.future.yield(value: success) diff --git a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift index 5845267f..2952bf8b 100644 --- a/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift +++ b/Tests/ConnectionPoolModuleTests/ConnectionRequestTests.swift @@ -1,17 +1,19 @@ @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils import XCTest final class ConnectionRequestTests: XCTestCase { func testHappyPath() async throws { let mockConnection = MockConnection(id: 1) - let connection = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + let lease = try await withCheckedThrowingContinuation { (continuation: CheckedContinuation, any Error>) in let request = ConnectionRequest(id: 42, continuation: continuation) XCTAssertEqual(request.id, 42) - continuation.resume(with: .success(mockConnection)) + let lease = ConnectionLease(connection: mockConnection) { _ in } + continuation.resume(with: .success(lease)) } - XCTAssert(connection === mockConnection) + XCTAssert(lease.connection === mockConnection) } func testSadPath() async throws { diff --git a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift b/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift deleted file mode 100644 index 6aaa9c91..00000000 --- a/Tests/ConnectionPoolModuleTests/Mocks/MockRequest.swift +++ /dev/null @@ -1,28 +0,0 @@ -import _ConnectionPoolModule - -final class MockRequest: ConnectionRequestProtocol, Hashable, Sendable { - typealias Connection = MockConnection - - struct ID: Hashable { - var objectID: ObjectIdentifier - - init(_ request: MockRequest) { - self.objectID = ObjectIdentifier(request) - } - } - - var id: ID { ID(self) } - - - static func ==(lhs: MockRequest, rhs: MockRequest) -> Bool { - lhs.id == rhs.id - } - - func hash(into hasher: inout Hasher) { - hasher.combine(self.id) - } - - func complete(with: Result) { - - } -} diff --git a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift index b817ce19..b1b54023 100644 --- a/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift +++ b/Tests/ConnectionPoolModuleTests/NoKeepAliveBehaviorTests.swift @@ -1,4 +1,5 @@ import _ConnectionPoolModule +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift index 6b8d6c6e..b09bfcb4 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionGroupTests.swift @@ -1,5 +1,6 @@ -import XCTest @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) final class PoolStateMachine_ConnectionGroupTests: XCTestCase { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift index bc4c2c4b..7dd2b726 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+ConnectionStateTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift index 0231da51..ddd6a71e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachine+RequestQueueTests.swift @@ -1,4 +1,5 @@ @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) @@ -10,7 +11,7 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) XCTAssertEqual(queue.count, 1) XCTAssertFalse(queue.isEmpty) @@ -24,11 +25,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -48,11 +49,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -75,11 +76,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { @@ -112,11 +113,11 @@ final class PoolStateMachine_RequestQueueTests: XCTestCase { var queue = TestQueue() XCTAssert(queue.isEmpty) - var request1 = MockRequest() + var request1 = MockRequest(connectionType: MockConnection.self) queue.queue(request1) - var request2 = MockRequest() + var request2 = MockRequest(connectionType: MockConnection.self) queue.queue(request2) - var request3 = MockRequest() + var request3 = MockRequest(connectionType: MockConnection.self) queue.queue(request3) do { diff --git a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift index f5ada14f..08afdf8e 100644 --- a/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift +++ b/Tests/ConnectionPoolModuleTests/PoolStateMachineTests.swift @@ -1,13 +1,14 @@ -import XCTest @testable import _ConnectionPoolModule +import _ConnectionPoolTestUtils +import XCTest @available(macOS 13.0, iOS 16.0, tvOS 16.0, watchOS 9.0, *) typealias TestPoolStateMachine = PoolStateMachine< MockConnection, ConnectionIDGenerator, MockConnection.ID, - MockRequest, - MockRequest.ID, + MockRequest, + MockRequest.ID, MockTimerCancellationToken > @@ -74,7 +75,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) @@ -83,13 +84,13 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(stateMachine.releaseConnection(connection1, streams: 1), .none()) // lease connection 1 - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest2.request, .leaseConnection(.init(element: request2), connection1)) // request connection while none is available - let request3 = MockRequest() + let request3 = MockRequest(connectionType: MockConnection.self) let leaseRequest3 = stateMachine.leaseConnection(request3) XCTAssertEqual(leaseRequest3.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest3.request, .none) @@ -131,7 +132,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -143,7 +144,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .none) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -194,13 +195,13 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(createdAction1.connection, .scheduleTimers([])) // lease connection 1 - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .cancelTimers([])) XCTAssertEqual(leaseRequest1.request, .leaseConnection(.init(element: request1), connection1)) // request connection while none is available - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -244,7 +245,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -286,7 +287,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 0) // request connection while none exists - let request1 = MockRequest() + let request1 = MockRequest(connectionType: MockConnection.self) let leaseRequest1 = stateMachine.leaseConnection(request1) XCTAssertEqual(leaseRequest1.connection, .makeConnection(.init(connectionID: 0), [])) XCTAssertEqual(leaseRequest1.request, .none) @@ -308,7 +309,7 @@ final class PoolStateMachineTests: XCTestCase { connection1.closeIfClosing() // request connection while none exists anymore - let request2 = MockRequest() + let request2 = MockRequest(connectionType: MockConnection.self) let leaseRequest2 = stateMachine.leaseConnection(request2) XCTAssertEqual(leaseRequest2.connection, .makeConnection(.init(connectionID: 1), [])) XCTAssertEqual(leaseRequest2.request, .none) @@ -353,7 +354,7 @@ final class PoolStateMachineTests: XCTestCase { XCTAssertEqual(requests.count, 1) // one connection should exist - let request = MockRequest() + let request = MockRequest(connectionType: MockConnection.self) let leaseRequest = stateMachine.leaseConnection(request) XCTAssertEqual(leaseRequest.connection, .none) XCTAssertEqual(leaseRequest.request, .none) @@ -375,6 +376,10 @@ final class PoolStateMachineTests: XCTestCase { let connectionClosed = stateMachine.connectionClosed(connection) XCTAssertEqual(connectionClosed.connection, .makeConnection(.init(connectionID: 1), [])) connection.closeIfClosing() + let establishAction = stateMachine.connectionEstablished(.init(id: 1), maxStreams: 1) + XCTAssertEqual(establishAction.request, .none) + guard case .scheduleTimers(let timers) = establishAction.connection else { return XCTFail("Unexpected connection action") } + XCTAssertEqual(timers, [.init(.init(timerID: 0, connectionID: 1, usecase: .keepAlive), duration: configuration.keepAliveDuration!)]) } } diff --git a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift index 1a2836b9..6be22005 100644 --- a/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift +++ b/Tests/ConnectionPoolModuleTests/TinyFastSequenceTests.swift @@ -34,7 +34,7 @@ final class TinyFastSequenceTests: XCTestCase { guard case .n(let array) = emptySequence.base else { return XCTFail("Expected sequence to be backed by an array") } - XCTAssertEqual(array.capacity, 8) + XCTAssertGreaterThanOrEqual(array.capacity, 8) var oneElemSequence = TinyFastSequence(element: 1) oneElemSequence.reserveCapacity(8) @@ -43,14 +43,22 @@ final class TinyFastSequenceTests: XCTestCase { guard case .n(let array) = oneElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } - XCTAssertEqual(array.capacity, 8) + XCTAssertGreaterThanOrEqual(array.capacity, 8) var twoElemSequence = TinyFastSequence([1, 2]) twoElemSequence.reserveCapacity(8) + twoElemSequence.append(3) guard case .n(let array) = twoElemSequence.base else { return XCTFail("Expected sequence to be backed by an array") } - XCTAssertEqual(array.capacity, 8) + XCTAssertGreaterThanOrEqual(array.capacity, 8) + + var threeElemSequence = TinyFastSequence([1, 2, 3]) + threeElemSequence.reserveCapacity(8) + guard case .n(let array) = threeElemSequence.base else { + return XCTFail("Expected sequence to be backed by an array") + } + XCTAssertGreaterThanOrEqual(array.capacity, 8) } func testNewSequenceSlowPath() { diff --git a/Tests/IntegrationTests/AsyncTests.swift b/Tests/IntegrationTests/AsyncTests.swift index 75e5b6ba..b4c8e93f 100644 --- a/Tests/IntegrationTests/AsyncTests.swift +++ b/Tests/IntegrationTests/AsyncTests.swift @@ -84,6 +84,36 @@ final class AsyncPostgresConnectionTests: XCTestCase { } } + func testAdditionalParametersTakeEffect() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + let query: PostgresQuery = """ + SELECT + current_setting('application_name'); + """ + + let applicationName = "postgres-nio-test" + var options = PostgresConnection.Configuration.Options() + options.additionalStartupParameters = [ + ("application_name", applicationName) + ] + + try await withTestConnection(on: eventLoop, options: options) { connection in + let rows = try await connection.query(query, logger: .psqlTest) + var counter = 0 + + for try await element in rows.decode(String.self) { + XCTAssertEqual(element, applicationName) + + counter += 1 + } + + XCTAssertGreaterThanOrEqual(counter, 1) + } + } + func testSelectTimeoutWhileLongRunningQuery() async throws { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } @@ -225,25 +255,32 @@ final class AsyncPostgresConnectionTests: XCTestCase { } func testListenAndNotify() async throws { + let channelNames = [ + "foo", + "default" + ] + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let eventLoop = eventLoopGroup.next() - try await self.withTestConnection(on: eventLoop) { connection in - let stream = try await connection.listen("foo") - var iterator = stream.makeAsyncIterator() + for channelName in channelNames { + try await self.withTestConnection(on: eventLoop) { connection in + let stream = try await connection.listen(channelName) + var iterator = stream.makeAsyncIterator() - try await self.withTestConnection(on: eventLoop) { other in - try await other.query(#"NOTIFY foo, 'bar';"#, logger: .psqlTest) + try await self.withTestConnection(on: eventLoop) { other in + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'bar';"#, logger: .psqlTest) - try await other.query(#"NOTIFY foo, 'foo';"#, logger: .psqlTest) - } + try await other.query(#"NOTIFY "\#(unescaped: channelName)", 'foo';"#, logger: .psqlTest) + } - let first = try await iterator.next() - XCTAssertEqual(first?.payload, "bar") + let first = try await iterator.next() + XCTAssertEqual(first?.payload, "bar") - let second = try await iterator.next() - XCTAssertEqual(second?.payload, "foo") + let second = try await iterator.next() + XCTAssertEqual(second?.payload, "foo") + } } } @@ -439,17 +476,99 @@ final class AsyncPostgresConnectionTests: XCTestCase { XCTFail("Unexpected error: \(String(describing: error))") } } + + static let preparedStatementWithOptionalTestTable = "AsyncTestPreparedStatementWithOptionalTestTable" + func testPreparedStatementWithOptionalBinding() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + struct InsertPreparedStatement: PostgresPreparedStatement { + static let name = "INSERT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"INSERT INTO "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" (uuid) VALUES ($1);"# + typealias Row = () + + var uuid: UUID? + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.uuid) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + () + } + } + + struct SelectPreparedStatement: PostgresPreparedStatement { + static let name = "SELECT-AsyncTestPreparedStatementWithOptionalTestTable" + + static let sql = #"SELECT id, uuid FROM "\#(AsyncPostgresConnectionTests.preparedStatementWithOptionalTestTable)" WHERE id <= $1;"# + typealias Row = (Int, UUID?) + + var id: Int + + func makeBindings() -> PostgresBindings { + var bindings = PostgresBindings() + bindings.append(self.id) + return bindings + } + + func decodeRow(_ row: PostgresNIO.PostgresRow) throws -> Row { + try row.decode((Int, UUID?).self) + } + } + + do { + try await withTestConnection(on: eventLoop) { connection in + try await connection.query(""" + CREATE TABLE IF NOT EXISTS "\(unescaped: Self.preparedStatementWithOptionalTestTable)" ( + id SERIAL PRIMARY KEY, + uuid UUID + ) + """, + logger: .psqlTest + ) + + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: .init()), logger: .psqlTest) + _ = try await connection.execute(InsertPreparedStatement(uuid: nil), logger: .psqlTest) + + let rows = try await connection.execute(SelectPreparedStatement(id: 3), logger: .psqlTest) + var counter = 0 + for try await (id, uuid) in rows { + Logger.psqlTest.info("Received row", metadata: [ + "id": "\(id)", "uuid": "\(String(describing: uuid))" + ]) + counter += 1 + } + + try await connection.query(""" + DROP TABLE "\(unescaped: Self.preparedStatementWithOptionalTestTable)"; + """, + logger: .psqlTest + ) + } + } catch { + XCTFail("Unexpected error: \(String(describing: error))") + } + } } extension XCTestCase { func withTestConnection( on eventLoop: EventLoop, + options: PostgresConnection.Configuration.Options? = nil, file: StaticString = #filePath, line: UInt = #line, _ closure: (PostgresConnection) async throws -> Result ) async throws -> Result { - let connection = try await PostgresConnection.test(on: eventLoop).get() + let connection = try await PostgresConnection.test(on: eventLoop, options: options).get() do { let result = try await closure(connection) diff --git a/Tests/IntegrationTests/PSQLIntegrationTests.swift b/Tests/IntegrationTests/PSQLIntegrationTests.swift index 57939c06..d541899b 100644 --- a/Tests/IntegrationTests/PSQLIntegrationTests.swift +++ b/Tests/IntegrationTests/PSQLIntegrationTests.swift @@ -123,6 +123,25 @@ final class IntegrationTests: XCTestCase { XCTAssertEqual(foo, "hello") } + func testQueryNothing() throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let eventLoop = eventLoopGroup.next() + + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + var _result: PostgresQueryResult? + XCTAssertNoThrow(_result = try conn?.query(""" + -- Some comments + """, logger: .psqlTest).wait()) + + let result = try XCTUnwrap(_result) + XCTAssertEqual(result.rows, []) + XCTAssertEqual(result.metadata.command, "") + } + func testDecodeIntegers() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } diff --git a/Tests/IntegrationTests/PostgresClientTests.swift b/Tests/IntegrationTests/PostgresClientTests.swift index d6d89dc3..9ac92754 100644 --- a/Tests/IntegrationTests/PostgresClientTests.swift +++ b/Tests/IntegrationTests/PostgresClientTests.swift @@ -42,6 +42,145 @@ final class PostgresClientTests: XCTestCase { taskGroup.cancelAll() } } + + func testTransaction() async throws { + var mlogger = Logger(label: "test") + mlogger.logLevel = .debug + let logger = mlogger + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8) + self.addTeardownBlock { + try await eventLoopGroup.shutdownGracefully() + } + + let tableName = "test_client_transactions" + + let clientConfig = PostgresClient.Configuration.makeTestConfiguration() + let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger) + + do { + try await withThrowingTaskGroup(of: Void.self) { taskGroup in + taskGroup.addTask { + await client.run() + } + + try await client.query( + """ + CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" ( + id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, + uuid UUID NOT NULL + ); + """, + logger: logger + ) + + let iterations = 1000 + + for _ in 0..(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications.append(notification) + receivedNotifications.wrappingIncrement(ordering: .relaxed) + XCTAssertEqual(notification.channel, "example") + XCTAssertEqual(notification.payload, "") } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) // Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications.count, 1) - XCTAssertEqual(receivedNotifications.first?.channel, "example") - XCTAssertEqual(receivedNotifications.first?.payload, "") + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsNonEmptyPayload() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications: [PostgresMessage.NotificationResponse] = [] + let receivedNotifications = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications.append(notification) + receivedNotifications.wrappingIncrement(ordering: .relaxed) + XCTAssertEqual(notification.channel, "example") + XCTAssertEqual(notification.payload, "Notification payload example") } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example, 'Notification payload example'").wait()) // Notifications are asynchronous, so we should run at least one more query to make sure we'll have received the notification response by then XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications.count, 1) - XCTAssertEqual(receivedNotifications.first?.channel, "example") - XCTAssertEqual(receivedNotifications.first?.payload, "Notification payload example") + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsRemoveHandlerWithinHandler() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications = 0 + let receivedNotifications = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications += 1 + receivedNotifications.wrappingIncrement(ordering: .relaxed) context.stop() } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications, 1) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsRemoveHandlerOutsideHandler() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications = 0 + let receivedNotifications = ManagedAtomic(0) let context = conn?.addListener(channel: "example") { context, notification in - receivedNotifications += 1 + receivedNotifications.wrappingIncrement(ordering: .relaxed) } XCTAssertNotNil(context) XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) @@ -173,47 +174,47 @@ final class PostgresNIOTests: XCTestCase { context?.stop() XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications, 1) + XCTAssertEqual(receivedNotifications.load(ordering: .relaxed), 1) } func testNotificationsMultipleRegisteredHandlers() { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications1 = 0 + let receivedNotifications1 = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications1 += 1 + receivedNotifications1.wrappingIncrement(ordering: .relaxed) } - var receivedNotifications2 = 0 + let receivedNotifications2 = ManagedAtomic(0) conn?.addListener(channel: "example") { context, notification in - receivedNotifications2 += 1 + receivedNotifications2.wrappingIncrement(ordering: .relaxed) } XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications1, 1) - XCTAssertEqual(receivedNotifications2, 1) + XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1) + XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 1) } func testNotificationsMultipleRegisteredHandlersRemoval() throws { var conn: PostgresConnection? XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var receivedNotifications1 = 0 + let receivedNotifications1 = ManagedAtomic(0) XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in - receivedNotifications1 += 1 + receivedNotifications1.wrappingIncrement(ordering: .relaxed) context.stop() }) - var receivedNotifications2 = 0 + let receivedNotifications2 = ManagedAtomic(0) XCTAssertNotNil(conn?.addListener(channel: "example") { context, notification in - receivedNotifications2 += 1 + receivedNotifications2.wrappingIncrement(ordering: .relaxed) }) XCTAssertNoThrow(_ = try conn?.simpleQuery("LISTEN example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("NOTIFY example").wait()) XCTAssertNoThrow(_ = try conn?.simpleQuery("SELECT 1").wait()) - XCTAssertEqual(receivedNotifications1, 1) - XCTAssertEqual(receivedNotifications2, 2) + XCTAssertEqual(receivedNotifications1.load(ordering: .relaxed), 1) + XCTAssertEqual(receivedNotifications2.load(ordering: .relaxed), 2) } func testNotificationHandlerFiltersOnChannel() { @@ -929,6 +930,22 @@ final class PostgresNIOTests: XCTestCase { } } + func testJSONBDecodeString() { + var conn: PostgresConnection? + XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) + defer { XCTAssertNoThrow(try conn?.close().wait()) } + + do { + var rows: PostgresQueryResult? + XCTAssertNoThrow(rows = try conn?.query("select '{\"hello\": \"world\"}'::jsonb as data").wait()) + + var resultString: String? + XCTAssertNoThrow(resultString = try rows?.first?.decode(String.self, context: .default)) + + XCTAssertEqual(resultString, "{\"hello\": \"world\"}") + } + } + func testInt4RangeSerialize() async throws { let conn: PostgresConnection = try await PostgresConnection.test(on: eventLoop).get() self.addTeardownBlock { @@ -1031,28 +1048,6 @@ final class PostgresNIOTests: XCTestCase { } } - func testRemoteTLSServer() { - // postgres://uymgphwj:7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA@elmer.db.elephantsql.com:5432/uymgphwj - var conn: PostgresConnection? - let logger = Logger(label: "test") - let sslContext = try! NIOSSLContext(configuration: .makeClientConfiguration()) - let config = PostgresConnection.Configuration( - host: "elmer.db.elephantsql.com", - port: 5432, - username: "uymgphwj", - password: "7_tHbREdRwkqAdu4KoIS7hQnNxr8J1LA", - database: "uymgphwj", - tls: .require(sslContext) - ) - XCTAssertNoThrow(conn = try PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger).wait()) - defer { XCTAssertNoThrow( try conn?.close().wait() ) } - var rows: [PostgresRow]? - XCTAssertNoThrow(rows = try conn?.simpleQuery("SELECT version()").wait()) - XCTAssertEqual(rows?.count, 1) - let row = rows?.first?.makeRandomAccess() - XCTAssertEqual(row?[data: "version"].string?.contains("PostgreSQL"), true) - } - @available(*, deprecated, message: "Test deprecated functionality") func testFailingTLSConnectionClosesConnection() { // There was a bug (https://github.com/vapor/postgres-nio/issues/133) where we would hit @@ -1283,11 +1278,11 @@ final class PostgresNIOTests: XCTestCase { XCTAssertNoThrow(conn = try PostgresConnection.test(on: eventLoop).wait()) defer { XCTAssertNoThrow( try conn?.close().wait() ) } var queries: [[PostgresRow]]? - XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { query in + XCTAssertNoThrow(queries = try conn?.prepare(query: "SELECT $1::text as foo;", handler: { [eventLoop] query in let a = query.execute(["a"]) let b = query.execute(["b"]) let c = query.execute(["c"]) - return EventLoopFuture.whenAllSucceed([a, b, c], on: self.eventLoop) + return EventLoopFuture.whenAllSucceed([a, b, c], on: eventLoop) }).wait()) XCTAssertEqual(queries?.count, 3) var resultIterator = queries?.makeIterator() diff --git a/Tests/IntegrationTests/Utilities.swift b/Tests/IntegrationTests/Utilities.swift index 001d9ee4..91dbb62e 100644 --- a/Tests/IntegrationTests/Utilities.swift +++ b/Tests/IntegrationTests/Utilities.swift @@ -24,9 +24,9 @@ extension PostgresConnection { } } - static func test(on eventLoop: EventLoop) -> EventLoopFuture { + static func test(on eventLoop: EventLoop, options: Configuration.Options? = nil) -> EventLoopFuture { let logger = Logger(label: "postgres.connection.test") - let config = PostgresConnection.Configuration( + var config = PostgresConnection.Configuration( host: env("POSTGRES_HOSTNAME") ?? "localhost", port: env("POSTGRES_PORT").flatMap(Int.init(_:)) ?? 5432, username: env("POSTGRES_USER") ?? "test_username", @@ -34,7 +34,10 @@ extension PostgresConnection { database: env("POSTGRES_DB") ?? "test_database", tls: .disable ) - + if let options { + config.options = options + } + return PostgresConnection.connect(on: eventLoop, configuration: config, id: 0, logger: logger) } diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift index 40e32468..872664af 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/ExtendedQueryStateMachineTests.swift @@ -20,7 +20,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) XCTAssertEqual(state.noDataReceived(), .wait) XCTAssertEqual(state.bindCompleteReceived(), .wait) - XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows("DELETE 1"), logger: logger))) + XCTAssertEqual(state.commandCompletedReceived("DELETE 1"), .succeedQuery(promise, with: .init(value: .noRows(.tag("DELETE 1")), logger: logger))) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } @@ -77,7 +77,25 @@ class ExtendedQueryStateMachineTests: XCTestCase { XCTAssertEqual(state.commandCompletedReceived("SELECT 2"), .forwardStreamComplete([row5, row6], commandTag: "SELECT 2")) XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) } - + + func testExtendedQueryWithNoQuery() { + var state = ConnectionStateMachine.readyForQuery() + + let logger = Logger.psqlTest + let promise = EmbeddedEventLoop().makePromise(of: PSQLRowStream.self) + promise.fail(PSQLError.uncleanShutdown) // we don't care about the error at all. + let query: PostgresQuery = "-- some comments" + let queryContext = ExtendedQueryContext(query: query, logger: logger, promise: promise) + + XCTAssertEqual(state.enqueue(task: .extendedQuery(queryContext)), .sendParseDescribeBindExecuteSync(query)) + XCTAssertEqual(state.parseCompleteReceived(), .wait) + XCTAssertEqual(state.parameterDescriptionReceived(.init(dataTypes: [.int8])), .wait) + XCTAssertEqual(state.noDataReceived(), .wait) + XCTAssertEqual(state.bindCompleteReceived(), .wait) + XCTAssertEqual(state.emptyQueryResponseReceived(), .succeedQuery(promise, with: .init(value: .noRows(.emptyResponse), logger: logger))) + XCTAssertEqual(state.readyForQueryReceived(.idle), .fireEventReadyForQuery) + } + func testReceiveTotallyUnexpectedMessageInQuery() { var state = ConnectionStateMachine.readyForQuery() @@ -96,7 +114,7 @@ class ExtendedQueryStateMachineTests: XCTestCase { .failQuery(promise, with: psqlError, cleanupContext: .init(action: .close, tasks: [], error: psqlError, closePromise: nil))) } - func testExtendedQueryIsCancelledImmediatly() { + func testExtendedQueryIsCancelledImmediately() { var state = ConnectionStateMachine.readyForQuery() let logger = Logger.psqlTest diff --git a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift index f6c1ddf7..e35e93f7 100644 --- a/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift +++ b/Tests/PostgresNIOTests/New/Connection State Machine/PreparedStatementStateMachineTests.swift @@ -28,7 +28,7 @@ class PreparedStatementStateMachineTests: XCTestCase { XCTAssertEqual(preparationCompleteAction.statements.count, 1) XCTAssertNil(preparationCompleteAction.rowDescription) firstPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) @@ -46,7 +46,7 @@ class PreparedStatementStateMachineTests: XCTestCase { return } secondPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) @@ -135,12 +135,12 @@ class PreparedStatementStateMachineTests: XCTestCase { XCTAssertNil(preparationCompleteAction.rowDescription) firstPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) secondPreparedStatement.promise.succeed(PSQLRowStream( - source: .noRows(.success("tag")), + source: .noRows(.success(.tag("tag"))), eventLoop: eventLoop, logger: .psqlTest )) diff --git a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift index 769bde4b..3f406598 100644 --- a/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/Date+PSQLCodableTests.swift @@ -14,7 +14,7 @@ class Date_PSQLCodableTests: XCTestCase { var result: Date? XCTAssertNoThrow(result = try Date(from: &buffer, type: .timestamptz, format: .binary, context: .default)) - XCTAssertEqual(value, result) + XCTAssertEqual(value.timeIntervalSince1970, result?.timeIntervalSince1970 ?? 0, accuracy: 0.001) } func testDecodeRandomDate() { diff --git a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift index 6ff35130..c1843c2a 100644 --- a/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift +++ b/Tests/PostgresNIOTests/New/Data/String+PSQLCodableTests.swift @@ -31,18 +31,6 @@ class String_PSQLCodableTests: XCTestCase { } } - func testDecodeFailureFromInvalidType() { - let buffer = ByteBuffer() - let dataTypes: [PostgresDataType] = [.bool, .float4Array, .float8Array] - - for dataType in dataTypes { - var loopBuffer = buffer - XCTAssertThrowsError(try String(from: &loopBuffer, type: dataType, format: .binary, context: .default)) { - XCTAssertEqual($0 as? PostgresDecodingError.Code, .typeMismatch) - } - } - } - func testDecodeFromUUID() { let uuid = UUID() var buffer = ByteBuffer() @@ -64,4 +52,15 @@ class String_PSQLCodableTests: XCTestCase { XCTAssertEqual($0 as? PostgresDecodingError.Code, .failure) } } + + func testDecodeFromJSONB() { + let json = #"{"hello": "world"}"# + var buffer = ByteBuffer() + buffer.writeInteger(UInt8(1)) + buffer.writeString(json) + + var decoded: String? + XCTAssertNoThrow(decoded = try String(from: &buffer, type: .jsonb, format: .binary, context: .default)) + XCTAssertEqual(decoded, json) + } } diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift index 9614bf1e..0c6b37ef 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift @@ -28,6 +28,8 @@ struct PSQLBackendMessageEncoder: MessageToByteEncoder { case .commandComplete(let string): self.encode(messageID: message.id, payload: StringPayload(string), into: &buffer) + case .copyInResponse(let copyInResponse): + self.encode(messageID: message.id, payload: copyInResponse, into: &buffer) case .dataRow(let row): self.encode(messageID: message.id, payload: row, into: &buffer) @@ -99,6 +101,8 @@ extension PostgresBackendMessage { return .closeComplete case .commandComplete: return .commandComplete + case .copyInResponse: + return .copyInResponse case .dataRow: return .dataRow case .emptyQueryResponse: @@ -184,6 +188,16 @@ extension PostgresBackendMessage.BackendKeyData: PSQLMessagePayloadEncodable { } } +extension PostgresBackendMessage.CopyInResponse: PSQLMessagePayloadEncodable { + public func encode(into buffer: inout ByteBuffer) { + buffer.writeInteger(Int8(self.format.rawValue)) + buffer.writeInteger(Int16(self.columnFormats.count)) + for columnFormat in columnFormats { + buffer.writeInteger(Int16(columnFormat.rawValue)) + } + } +} + extension DataRow: PSQLMessagePayloadEncodable { public func encode(into buffer: inout ByteBuffer) { buffer.writeInteger(self.columnCount, as: Int16.self) diff --git a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift index 55ccd0a9..d913da22 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift @@ -168,6 +168,18 @@ extension PostgresFrontendMessage { ) ) + case .copyData: + return .copyData(CopyData(data: buffer)) + + case .copyDone: + return .copyDone + + case .copyFail: + guard let message = buffer.readNullTerminatedString() else { + throw PSQLPartialDecodingError.fieldNotDecodable(type: String.self) + } + return .copyFail(CopyFail(message: message)) + case .close: preconditionFailure("TODO: Unimplemented") diff --git a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift index 2532959a..5fc8144b 100644 --- a/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift +++ b/Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift @@ -36,6 +36,14 @@ enum PostgresFrontendMessage: Equatable { let secretKey: Int32 } + struct CopyData: Hashable { + var data: ByteBuffer + } + + struct CopyFail: Hashable { + var message: String + } + enum Close: Hashable { case preparedStatement(String) case portal(String) @@ -170,6 +178,9 @@ enum PostgresFrontendMessage: Equatable { case bind(Bind) case cancel(Cancel) + case copyData(CopyData) + case copyDone + case copyFail(CopyFail) case close(Close) case describe(Describe) case execute(Execute) @@ -186,6 +197,9 @@ enum PostgresFrontendMessage: Equatable { enum ID: UInt8, Equatable { case bind + case copyData + case copyDone + case copyFail case close case describe case execute @@ -201,12 +215,18 @@ enum PostgresFrontendMessage: Equatable { switch rawValue { case UInt8(ascii: "B"): self = .bind + case UInt8(ascii: "c"): + self = .copyDone case UInt8(ascii: "C"): self = .close + case UInt8(ascii: "d"): + self = .copyData case UInt8(ascii: "D"): self = .describe case UInt8(ascii: "E"): self = .execute + case UInt8(ascii: "f"): + self = .copyFail case UInt8(ascii: "H"): self = .flush case UInt8(ascii: "P"): @@ -230,6 +250,12 @@ enum PostgresFrontendMessage: Equatable { switch self { case .bind: return UInt8(ascii: "B") + case .copyData: + return UInt8(ascii: "d") + case .copyDone: + return UInt8(ascii: "c") + case .copyFail: + return UInt8(ascii: "f") case .close: return UInt8(ascii: "C") case .describe: @@ -263,6 +289,12 @@ extension PostgresFrontendMessage { return .bind case .cancel: preconditionFailure("Cancel messages don't have an identifier") + case .copyData: + return .copyData + case .copyDone: + return .copyDone + case .copyFail: + return .copyFail case .close: return .close case .describe: diff --git a/Tests/PostgresNIOTests/New/Messages/CopyTests.swift b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift new file mode 100644 index 00000000..de686ae5 --- /dev/null +++ b/Tests/PostgresNIOTests/New/Messages/CopyTests.swift @@ -0,0 +1,147 @@ +import XCTest +import NIOCore +import NIOTestUtils +@testable import PostgresNIO + +class CopyTests: XCTestCase { + func testDecodeCopyInResponseMessage() throws { + let expected: [PostgresBackendMessage] = [ + .copyInResponse(.init(format: .textual, columnFormats: [.textual, .textual])), + .copyInResponse(.init(format: .binary, columnFormats: [.binary, .binary])), + .copyInResponse(.init(format: .binary, columnFormats: [.textual, .binary])) + ] + + var buffer = ByteBuffer() + + for message in expected { + guard case .copyInResponse(let message) = message else { + return XCTFail("Expected only to get copyInResponse here!") + } + buffer.writeBackendMessage(id: .copyInResponse ) { buffer in + buffer.writeInteger(Int8(message.format.rawValue)) + buffer.writeInteger(Int16(message.columnFormats.count)) + for columnFormat in message.columnFormats { + buffer.writeInteger(UInt16(columnFormat.rawValue)) + } + } + } + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, expected)], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + } + + func testDecodeFailureBecauseOfEmptyMessage() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { _ in} + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + + func testDecodeFailureBecauseOfInvalidFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfMissingColumnNumber() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + + func testDecodeFailureBecauseOfMissingColumns() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(20)) // 20 columns promised, none given + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testDecodeFailureBecauseOfInvalidColumnFormat() { + var buffer = ByteBuffer() + buffer.writeBackendMessage(id: .copyInResponse) { buffer in + buffer.writeInteger(Int8(0)) + buffer.writeInteger(Int16(1)) + buffer.writeInteger(Int8(20)) // Only 0 and 1 are valid formats + } + + XCTAssertThrowsError( + try ByteToMessageDecoderVerifier.verifyDecoder( + inputOutputPairs: [(buffer, [])], + decoderFactory: { PostgresBackendMessageDecoder(hasAlreadyReceivedBytes: true) } + ) + ) { + XCTAssert($0 is PostgresMessageDecodingError) + } + } + + func testEncodeCopyDataHeader() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDataHeader(dataLength: 3) + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.copyData.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 7) + } + + func testEncodeCopyDone() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyDone() + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 5) + XCTAssertEqual(PostgresFrontendMessage.ID.copyDone.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 4) + } + + func testEncodeCopyFail() { + var encoder = PostgresFrontendMessageEncoder(buffer: .init()) + encoder.copyFail(message: "Oh, no :(") + var byteBuffer = encoder.flushBuffer() + + XCTAssertEqual(byteBuffer.readableBytes, 15) + XCTAssertEqual(PostgresFrontendMessage.ID.copyFail.rawValue, byteBuffer.readInteger(as: UInt8.self)) + XCTAssertEqual(byteBuffer.readInteger(as: Int32.self), 14) + XCTAssertEqual(byteBuffer.readNullTerminatedString(), "Oh, no :(") + } +} diff --git a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift index 9a1e9e41..65ca26c3 100644 --- a/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift +++ b/Tests/PostgresNIOTests/New/PSQLRowStreamTests.swift @@ -12,7 +12,7 @@ final class PSQLRowStreamTests: XCTestCase { func testEmptyStream() { let stream = PSQLRowStream( - source: .noRows(.success("INSERT 0 1")), + source: .noRows(.success(.tag("INSERT 0 1"))), eventLoop: self.eventLoop, logger: self.logger ) diff --git a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift index dfdcc53e..206f38a3 100644 --- a/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift @@ -48,7 +48,7 @@ class PostgresChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in addSSLCallbackIsHit = true } let embedded = EmbeddedChannel(handlers: [ @@ -84,7 +84,7 @@ class PostgresChannelHandlerTests: XCTestCase { var config = self.testConnectionConfiguration() XCTAssertNoThrow(config.tls = .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) var addSSLCallbackIsHit = false - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in addSSLCallbackIsHit = true } let eventHandler = TestEventHandler() @@ -114,7 +114,7 @@ class PostgresChannelHandlerTests: XCTestCase { func testSSLUnsupportedClosesConnection() throws { let config = self.testConnectionConfiguration(tls: .require(try NIOSSLContext(configuration: .makeClientConfiguration()))) - let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel in + let handler = PostgresChannelHandler(configuration: config, eventLoop: self.eventLoop) { channel, _ in XCTFail("This callback should never be exectuded") throw PSQLError.sslUnsupported } @@ -124,7 +124,7 @@ class PostgresChannelHandlerTests: XCTestCase { handler ], loop: self.eventLoop) let eventHandler = TestEventHandler() - try embedded.pipeline.addHandler(eventHandler, position: .last).wait() + try embedded.pipeline.syncOperations.addHandler(eventHandler, position: .last) embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil) XCTAssertTrue(embedded.isActive) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index f2cd96f8..d0f8e2b0 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -38,6 +38,48 @@ class PostgresConnectionTests: XCTestCase { } } + func testOptionsAreSentOnTheWire() async throws { + let eventLoop = NIOAsyncTestingEventLoop() + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } + try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) + + let configuration = { + var config = PostgresConnection.Configuration( + establishedChannel: channel, + username: "username", + password: "postgres", + database: "database" + ) + config.options.additionalStartupParameters = [ + ("DateStyle", "ISO, MDY"), + ("application_name", "postgres-nio-test"), + ("server_encoding", "UTF8"), + ("integer_datetimes", "on"), + ("client_encoding", "UTF8"), + ("TimeZone", "Etc/UTC"), + ("is_superuser", "on"), + ("server_version", "13.1 (Debian 13.1-1.pgdg100+1)"), + ("session_authorization", "postgres"), + ("IntervalStyle", "postgres"), + ("standard_conforming_strings", "on") + ] + return config + }() + + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: .psqlTest) + let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) + XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: configuration.options.additionalStartupParameters, replication: .false)))) + try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) + try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + let connection = try await connectionPromise + try await connection.close() + } + func testSimpleListen() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() @@ -51,7 +93,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -63,7 +105,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo"))) let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, "UNLISTEN foo;") + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -111,7 +153,7 @@ class PostgresConnectionTests: XCTestCase { } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -124,7 +166,7 @@ class PostgresConnectionTests: XCTestCase { try await channel.writeInbound(PostgresBackendMessage.notification(.init(backendPID: 12, channel: "foo", payload: "wooohooo2"))) let unlistenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(unlistenMessage.parse.query, "UNLISTEN foo;") + XCTAssertEqual(unlistenMessage.parse.query, #"UNLISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -145,7 +187,7 @@ class PostgresConnectionTests: XCTestCase { func testSimpleListenConnectionDrops() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup in taskGroup.addTask { let events = try await connection.listen("foo") var iterator = events.makeAsyncIterator() @@ -155,12 +197,12 @@ class PostgresConnectionTests: XCTestCase { _ = try await iterator.next() XCTFail("Did not expect to not throw") } catch { - self.logger.error("error", metadata: ["error": "\(error)"]) + logger.error("error", metadata: ["error": "\(error)"]) } } let listenMessage = try await channel.waitForUnpreparedRequest() - XCTAssertEqual(listenMessage.parse.query, "LISTEN foo;") + XCTAssertEqual(listenMessage.parse.query, #"LISTEN "foo";"#) try await channel.writeInbound(PostgresBackendMessage.parseComplete) try await channel.writeInbound(PostgresBackendMessage.parameterDescription(.init(dataTypes: []))) @@ -184,10 +226,10 @@ class PostgresConnectionTests: XCTestCase { func testCloseGracefullyClosesWhenInternalQueueIsEmpty() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in for _ in 1...2 { taskGroup.addTask { - let rows = try await connection.query("SELECT 1;", logger: self.logger) + let rows = try await connection.query("SELECT 1;", logger: logger) var iterator = rows.decode(Int.self).makeAsyncIterator() let first = try await iterator.next() XCTAssertEqual(first, 1) @@ -244,10 +286,10 @@ class PostgresConnectionTests: XCTestCase { func testCloseClosesImmediatly() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() - try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + try await withThrowingTaskGroup(of: Void.self) { [logger] taskGroup async throws -> () in for _ in 1...2 { taskGroup.addTask { - try await connection.query("SELECT 1;", logger: self.logger) + try await connection.query("SELECT 1;", logger: logger) } } @@ -277,8 +319,9 @@ class PostgresConnectionTests: XCTestCase { func testIfServerJustClosesTheErrorReflectsThat() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + let logger = self.logger - async let response = try await connection.query("SELECT 1;", logger: self.logger) + async let response = try await connection.query("SELECT 1;", logger: logger) let listenMessage = try await channel.waitForUnpreparedRequest() XCTAssertEqual(listenMessage.parse.query, "SELECT 1;") @@ -374,6 +417,16 @@ class PostgresConnectionTests: XCTestCase { } } + func testWeDontCrashOnUnexpectedChannelEvents() async throws { + let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() + + enum MyEvent { + case pleaseDontCrash + } + channel.pipeline.fireUserInboundEventTriggered(MyEvent.pleaseDontCrash) + try await connection.close() + } + func testSerialExecutionOfSamePreparedStatement() async throws { let (connection, channel) = try await self.makeTestConnectionWithAsyncTestingChannel() @@ -587,10 +640,10 @@ class PostgresConnectionTests: XCTestCase { func makeTestConnectionWithAsyncTestingChannel() async throws -> (PostgresConnection, NIOAsyncTestingChannel) { let eventLoop = NIOAsyncTestingEventLoop() - let channel = await NIOAsyncTestingChannel(handlers: [ - ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()), - ReverseMessageToByteHandler(PSQLBackendMessageEncoder()), - ], loop: eventLoop) + let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in + try channel.pipeline.syncOperations.addHandlers(ReverseByteToMessageHandler(PSQLFrontendMessageDecoder())) + try channel.pipeline.syncOperations.addHandlers(ReverseMessageToByteHandler(PSQLBackendMessageEncoder())) + } try await channel.connect(to: .makeAddressResolvingHost("localhost", port: 5432)) let configuration = PostgresConnection.Configuration( @@ -600,7 +653,8 @@ class PostgresConnectionTests: XCTestCase { database: "database" ) - async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger) + let logger = self.logger + async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: logger) let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self) XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false)))) try await channel.writeInbound(PostgresBackendMessage.authentication(.ok)) diff --git a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift index 816daf04..9d662252 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowSequenceTests.swift @@ -1,6 +1,6 @@ import Atomics import NIOEmbedded -import Dispatch +import NIOPosix import XCTest @testable import PostgresNIO import NIOCore @@ -8,10 +8,10 @@ import Logging final class PostgresRowSequenceTests: XCTestCase { let logger = Logger(label: "PSQLRowStreamTests") - let eventLoop = EmbeddedEventLoop() func testBackpressureWorks() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -19,7 +19,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -41,6 +41,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testCancellationWorksWhileIterating() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -48,7 +49,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -72,6 +73,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testCancellationWorksBeforeIterating() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -79,7 +81,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -97,6 +99,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testDroppingTheSequenceCancelsTheSource() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -104,7 +107,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -117,6 +120,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamBasedOnCompletedQuery() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -124,7 +128,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -144,6 +148,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamIfInitializedWithAllData() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -151,7 +156,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -172,6 +177,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamIfInitializedWithError() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -179,7 +185,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -200,6 +206,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testSucceedingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( source: .stream( [ @@ -207,14 +214,14 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: eventLoop, logger: self.logger ) - let rowSequence = stream.asyncSequence() + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() var rowIterator = rowSequence.makeAsyncIterator() - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) } @@ -222,7 +229,7 @@ final class PostgresRowSequenceTests: XCTestCase { let row1 = try await rowIterator.next() XCTAssertEqual(try row1?.decode(Int.self), 0) - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .success("SELECT 1")) } @@ -232,6 +239,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testFailingRowContinuationsWorks() async throws { let dataSource = MockRowDataSource() + let eventLoop = NIOSingletons.posixEventLoopGroup.next() let stream = PSQLRowStream( source: .stream( [ @@ -239,14 +247,14 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: eventLoop, logger: self.logger ) - let rowSequence = stream.asyncSequence() + let rowSequence = try await eventLoop.submit { stream.asyncSequence() }.get() var rowIterator = rowSequence.makeAsyncIterator() - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { let dataRows: [DataRow] = (0..<1).map { [ByteBuffer(integer: Int64($0))] } stream.receive(dataRows) } @@ -254,7 +262,7 @@ final class PostgresRowSequenceTests: XCTestCase { let row1 = try await rowIterator.next() XCTAssertEqual(try row1?.decode(Int.self), 0) - DispatchQueue.main.asyncAfter(deadline: .now() + .seconds(1)) { + eventLoop.scheduleTask(in: .seconds(1)) { stream.receive(completion: .failure(PSQLError.serverClosedConnection(underlying: nil))) } @@ -268,6 +276,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testAdaptiveRowBufferShrinksAndGrows() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -275,7 +284,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -332,6 +341,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testAdaptiveRowShrinksToMin() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -339,7 +349,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) @@ -386,6 +396,7 @@ final class PostgresRowSequenceTests: XCTestCase { func testStreamBufferAcceptsNewRowsEventhoughItDidntAskForIt() async throws { let dataSource = MockRowDataSource() + let embeddedEventLoop = EmbeddedEventLoop() let stream = PSQLRowStream( source: .stream( [ @@ -393,7 +404,7 @@ final class PostgresRowSequenceTests: XCTestCase { ], dataSource ), - eventLoop: self.eventLoop, + eventLoop: embeddedEventLoop, logger: self.logger ) diff --git a/Tests/PostgresNIOTests/New/PostgresRowTests.swift b/Tests/PostgresNIOTests/New/PostgresRowTests.swift index 7aa4c7e6..4a34e3b0 100644 --- a/Tests/PostgresNIOTests/New/PostgresRowTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresRowTests.swift @@ -172,7 +172,7 @@ final class PostgresRowTests: XCTestCase { name: "name", tableOID: 1, columnAttributeNumber: 1, - dataType: .int8, + dataType: .text, dataTypeSize: 0, dataTypeModifier: 0, format: .binary @@ -185,7 +185,7 @@ final class PostgresRowTests: XCTestCase { columns: rowDescription ) - XCTAssertThrowsError(try row.decode((UUID?, String).self)) { error in + XCTAssertThrowsError(try row.decode((UUID?, Int).self)) { error in guard let psqlError = error as? PostgresDecodingError else { return XCTFail("Unexpected error type") } XCTAssertEqual(psqlError.columnName, "name") @@ -194,8 +194,8 @@ final class PostgresRowTests: XCTestCase { XCTAssertEqual(psqlError.file, #fileID) XCTAssertEqual(psqlError.postgresData, ByteBuffer(integer: 123)) XCTAssertEqual(psqlError.postgresFormat, .binary) - XCTAssertEqual(psqlError.postgresType, .int8) - XCTAssert(psqlError.targetType == String.self) + XCTAssertEqual(psqlError.postgresType, .text) + XCTAssert(psqlError.targetType == Int.self) } } }