diff --git a/Package.swift b/Package.swift index 309bba4..478eca2 100644 --- a/Package.swift +++ b/Package.swift @@ -33,6 +33,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-nio.git", from: "2.56.0"), .package(url: "https://github.com/apple/swift-log.git", from: "1.0.0"), .package(url: "https://github.com/swift-server/swift-service-lifecycle.git", from: "2.0.0"), + .package(url: "https://github.com/vapor/postgres-nio.git", from: "1.19.0"), ], targets: [ .target( @@ -43,6 +44,7 @@ let package = Package( .product(name: "NIOEmbedded", package: "swift-nio"), .product(name: "Logging", package: "swift-log"), .product(name: "ServiceLifecycle", package: "swift-service-lifecycle"), + .product(name: "_ConnectionPoolModule", package: "postgres-nio"), ] ), .testTarget( diff --git a/Sources/Memcache/MemcacheConnection.swift b/Sources/Memcache/MemcacheConnection.swift index 0ef63bd..ddc6677 100644 --- a/Sources/Memcache/MemcacheConnection.swift +++ b/Sources/Memcache/MemcacheConnection.swift @@ -11,8 +11,9 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -@_spi(AsyncChannel) +import _ConnectionPoolModule +import Atomics import NIOCore import NIOPosix import ServiceLifecycle @@ -20,7 +21,17 @@ import ServiceLifecycle /// An actor to create a connection to a Memcache server. /// /// This actor can be used to send commands to the server. -public actor MemcacheConnection: Service { +public actor MemcacheConnection: Service, PooledConnection { + public typealias ID = Int + public let id: ID + private static var nextID: ManagedAtomic = ManagedAtomic(0) + + private let closePromise: EventLoopPromise + + public var closeFuture: EventLoopFuture { + return self.closePromise.futureResult + } + private typealias StreamElement = (MemcacheRequest, CheckedContinuation) private let host: String private let port: Int @@ -56,23 +67,63 @@ public actor MemcacheConnection: Service { private var state: State - /// Initialize a new MemcacheConnection. + /// Initialize a new MemcacheConnection, with an option to specify an ID. + /// If no ID is provided, a default value is used. /// /// - Parameters: /// - host: The host address of the Memcache server. /// - port: The port number of the Memcache server. /// - eventLoopGroup: The event loop group to use for this connection. - public init(host: String, port: Int, eventLoopGroup: EventLoopGroup) { + /// - id: The unique identifier for the connection (optional). + public init(host: String, port: Int, id: ID? = nil, eventLoopGroup: EventLoopGroup) { self.host = host self.port = port + self.id = id ?? MemcacheConnection.nextID.wrappingIncrementThenLoad(ordering: .sequentiallyConsistent) let (stream, continuation) = AsyncStream.makeStream() let bufferAllocator = ByteBufferAllocator() - self.state = .initial( - eventLoopGroup: eventLoopGroup, - bufferAllocator: bufferAllocator, - requestStream: stream, - requestContinuation: continuation - ) + self.closePromise = eventLoopGroup.next().makePromise(of: Void.self) + self.state = .initial(eventLoopGroup: eventLoopGroup, bufferAllocator: bufferAllocator, requestStream: stream, requestContinuation: continuation) + } + + deinit { + // Fulfill the promise if it has not been fulfilled yet + closePromise.fail(MemcacheError(code: .connectionShutdown, + message: "MemcacheConnection deinitialized without closing", + cause: nil, + location: .here())) + } + + /// Closes the connection. This method is responsible for properly shutting down + /// and cleaning up resources associated with the connection. + public nonisolated func close() { + Task { + await self.closeConnection() + } + } + + private func closeConnection() async { + switch self.state { + case .running(_, let channel, _, _): + channel.channel.close().cascade(to: self.closePromise) + default: + self.closePromise.succeed(()) + } + self.state = .finished + } + + /// Registers a closure to be called when the connection is closed. + /// This is useful for performing cleanup or notification tasks. + public nonisolated func onClose(_ closure: @escaping ((any Error)?) -> Void) { + Task { + await self.closeFuture.whenComplete { result in + switch result { + case .success: + closure(nil) + case .failure(let error): + closure(error) + } + } + } } /// Runs the Memcache connection. @@ -95,7 +146,7 @@ public actor MemcacheConnection: Service { return channel.eventLoop.makeCompletedFuture { try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(MemcacheRequestEncoder())) try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(MemcacheResponseDecoder())) - return try NIOAsyncChannel(synchronouslyWrapping: channel) + return try NIOAsyncChannel(wrappingChannelSynchronously: channel) } }.get() @@ -106,39 +157,41 @@ public actor MemcacheConnection: Service { requestContinuation: continuation ) - var iterator = channel.inboundStream.makeAsyncIterator() switch self.state { case .running(_, let channel, let requestStream, let requestContinuation): - for await (request, continuation) in requestStream { - do { - try await channel.outboundWriter.write(request) - let responseBuffer = try await iterator.next() - - if let response = responseBuffer { - continuation.resume(returning: response) - } else { - self.state = .finished - requestContinuation.finish() - continuation.resume(throwing: MemcacheError( - code: .connectionShutdown, - message: "The connection to the Memcache server was unexpectedly closed.", - cause: nil, - location: .here() - )) - } - } catch { - switch self.state { - case .running: - self.state = .finished - requestContinuation.finish() - continuation.resume(throwing: MemcacheError( - code: .connectionShutdown, - message: "The connection to the Memcache server has shut down while processing a request.", - cause: error, - location: .here() - )) - case .initial, .finished: - break + try await channel.executeThenClose { inbound, outbound in + var inboundIterator = inbound.makeAsyncIterator() + for await (request, continuation) in requestStream { + do { + try await outbound.write(request) + let responseBuffer = try await inboundIterator.next() + + if let response = responseBuffer { + continuation.resume(returning: response) + } else { + self.state = .finished + requestContinuation.finish() + continuation.resume(throwing: MemcacheError( + code: .connectionShutdown, + message: "The connection to the Memcache server was unexpectedly closed.", + cause: nil, + location: .here() + )) + } + } catch { + switch self.state { + case .running: + self.state = .finished + requestContinuation.finish() + continuation.resume(throwing: MemcacheError( + code: .connectionShutdown, + message: "The connection to the Memcache server has shut down while processing a request.", + cause: error, + location: .here() + )) + case .initial, .finished: + break + } } } }