diff --git a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift index f8d0729d8a..c8262d6c0a 100644 --- a/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift +++ b/Sources/NIOTLS/NIOTypedApplicationProtocolNegotiationHandler.swift @@ -49,10 +49,14 @@ public final class NIOTypedApplicationProtocolNegotiationHandler> { - self.negotiatedPromise.futureResult + return self.negotiatedPromise.futureResult } - private let negotiatedPromise: EventLoopPromise> + private var negotiatedPromise: EventLoopPromise> { + precondition(self._negotiatedPromise != nil, "Tried to access the protocol negotiation result before the handler was added to a pipeline") + return self._negotiatedPromise! + } + private var _negotiatedPromise: EventLoopPromise>? private let completionHandler: (ALPNResult, Channel) -> EventLoopFuture> private var stateMachine = ProtocolNegotiationHandlerStateMachine>() @@ -63,9 +67,8 @@ public final class NIOTypedApplicationProtocolNegotiationHandler EventLoopFuture>) { + public init(alpnCompleteHandler: @escaping (ALPNResult, Channel) -> EventLoopFuture>) { self.completionHandler = alpnCompleteHandler - self.negotiatedPromise = eventLoop.makePromise(of: NIOProtocolNegotiationResult.self) } /// Create an `ApplicationProtocolNegotiationHandler` with the given completion @@ -74,14 +77,18 @@ public final class NIOTypedApplicationProtocolNegotiationHandler EventLoopFuture>) { - self.init(eventLoop: eventLoop) { result, _ in + public convenience init(alpnCompleteHandler: @escaping (ALPNResult) -> EventLoopFuture>) { + self.init { result, _ in alpnCompleteHandler(result) } } - deinit { - switch self.stateMachine.deinitHandler() { + public func handlerAdded(context: ChannelHandlerContext) { + self._negotiatedPromise = context.eventLoop.makePromise() + } + + public func handlerRemoved(context: ChannelHandlerContext) { + switch self.stateMachine.handlerRemoved() { case .failPromise: self.negotiatedPromise.fail(ChannelError.inappropriateOperationForState) diff --git a/Sources/NIOTLS/ProtocolNegotiationHandlerStateMachine.swift b/Sources/NIOTLS/ProtocolNegotiationHandlerStateMachine.swift index 380a678529..6f00b02a8c 100644 --- a/Sources/NIOTLS/ProtocolNegotiationHandlerStateMachine.swift +++ b/Sources/NIOTLS/ProtocolNegotiationHandlerStateMachine.swift @@ -30,21 +30,16 @@ struct ProtocolNegotiationHandlerStateMachine { private var state = State.initial @usableFromInline - enum DeinitHandlerAction { + enum HandlerRemovedAction { case failPromise } @inlinable - mutating func deinitHandler() -> DeinitHandlerAction? { + mutating func handlerRemoved() -> HandlerRemovedAction? { switch self.state { - case .initial: + case .initial, .waitingForUser, .unbuffering: return .failPromise - case .waitingForUser, .unbuffering: - // We are retaining the handler strongly while waiting and unbuffering - // so we should never hit the deinit. - fatalError("Unexpected state") - case .finished: return .none } diff --git a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift index 44843b558e..7e9061b88d 100644 --- a/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift +++ b/Tests/NIOPosixTests/AsyncChannelBootstrapTests.swift @@ -989,7 +989,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN)) - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: channel.eventLoop) { alpnResult, channel in + let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in switch alpnResult { case .negotiated(let alpn): switch alpn { @@ -1020,7 +1020,7 @@ final class AsyncChannelBootstrapTests: XCTestCase { @discardableResult private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> EventLoopFuture> { - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: channel.eventLoop) { alpnResult, channel in + let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in switch alpnResult { case .negotiated(let alpn): switch alpn { diff --git a/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift b/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift index 6218e50f4d..99bc3e795d 100644 --- a/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift +++ b/Tests/NIOTLSTests/NIOTypedApplicationProtocolNegotiationHandlerTests.swift @@ -27,16 +27,15 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { private let negotiatedEvent: TLSUserEvent = .handshakeCompleted(negotiatedProtocol: "h2") private let negotiatedResult: ALPNResult = .negotiated("h2") - func testPromiseIsCompleted() { + func testPromiseIsCompleted() throws { let channel = EmbeddedChannel() - let eventLoop = channel.embeddedEventLoop - var handler: NIOTypedApplicationProtocolNegotiationHandler? = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: eventLoop) { result, channel in + let handler = NIOTypedApplicationProtocolNegotiationHandler { result, channel in return channel.eventLoop.makeSucceededFuture(.init(result: (.negotiated(result)))) } - let future = handler!.protocolNegotiationResult - handler = nil - XCTAssertThrowsError(try future.wait()) { error in + try channel.pipeline.addHandler(handler).wait() + try channel.pipeline.removeHandler(handler).wait() + XCTAssertThrowsError(try handler.protocolNegotiationResult.wait()) { error in XCTAssertEqual(error as? ChannelError, .inappropriateOperationForState) } } @@ -46,7 +45,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let loop = emChannel.eventLoop as! EmbeddedEventLoop var called = false - let handler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: loop) { result, channel in + let handler = NIOTypedApplicationProtocolNegotiationHandler { result, channel in called = true XCTAssertEqual(result, self.negotiatedResult) XCTAssertTrue(emChannel === channel) @@ -64,7 +63,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let handler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: loop) { result in + let handler = NIOTypedApplicationProtocolNegotiationHandler { result in XCTFail("Negotiation fired") return loop.makeSucceededFuture(.init(result: (.failed))) } @@ -85,7 +84,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let channel = EmbeddedChannel() let loop = channel.eventLoop as! EmbeddedEventLoop - let handler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: loop) { result in + let handler = NIOTypedApplicationProtocolNegotiationHandler { result in XCTFail("Should not be called") return loop.makeSucceededFuture(.init(result: (.failed))) } @@ -104,7 +103,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let loop = channel.eventLoop as! EmbeddedEventLoop let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) - let handler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: loop) { result in + let handler = NIOTypedApplicationProtocolNegotiationHandler { result in return continuePromise.futureResult } @@ -135,7 +134,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let loop = channel.eventLoop as! EmbeddedEventLoop let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) - let handler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: loop) { result in + let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult } let eventCounterHandler = EventCounterHandler() @@ -162,7 +161,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let loop = channel.eventLoop as! EmbeddedEventLoop let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) - let handler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: loop) { result in + let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult } let eventCounterHandler = EventCounterHandler() @@ -193,7 +192,7 @@ final class NIOTypedApplicationProtocolNegotiationHandlerTests: XCTestCase { let loop = channel.eventLoop as! EmbeddedEventLoop let continuePromise = loop.makePromise(of: NIOProtocolNegotiationResult.self) - let handler = NIOTypedApplicationProtocolNegotiationHandler(eventLoop: loop) { result in + let handler = NIOTypedApplicationProtocolNegotiationHandler { result in continuePromise.futureResult } let eventCounterHandler = EventCounterHandler()