diff --git a/Sources/NIOSSL/NIOSSLClientHandler.swift b/Sources/NIOSSL/NIOSSLClientHandler.swift index 24f19291..57fde8e4 100644 --- a/Sources/NIOSSL/NIOSSLClientHandler.swift +++ b/Sources/NIOSSL/NIOSSLClientHandler.swift @@ -31,6 +31,11 @@ private extension String { /// This handler can be used in channels that are acting as the client /// in the TLS dialog. For server connections, use the `NIOSSLServerHandler`. public final class NIOSSLClientHandler: NIOSSLHandler { + public convenience init(context: NIOSSLContext, serverHostname: String?) throws { + try self.init(context: context, serverHostname: serverHostname, optionalCustomVerificationCallback: nil) + } + + @available(*, deprecated, renamed: "init(context:serverHostname:customVerificationCallback:)") public init(context: NIOSSLContext, serverHostname: String?, verificationCallback: NIOSSLVerificationCallback? = nil) throws { guard let connection = context.createConnection() else { throw NIOSSLError.unableToAllocateBoringSSLObject @@ -39,7 +44,7 @@ public final class NIOSSLClientHandler: NIOSSLHandler { connection.setConnectState() if let serverHostname = serverHostname { if serverHostname.isIPAddress() { - throw BoringSSLError.invalidSNIName([]) + throw NIOSSLExtraError.cannotUseIPAddressInSNI(ipAddress: serverHostname) } // IP addresses must not be provided in the SNI extension, so filter them. @@ -52,4 +57,32 @@ public final class NIOSSLClientHandler: NIOSSLHandler { super.init(connection: connection, shutdownTimeout: context.configuration.shutdownTimeout) } + + + public convenience init(context: NIOSSLContext, serverHostname: String?, customVerificationCallback: @escaping NIOSSLCustomVerificationCallback) throws { + try self.init(context: context, serverHostname: serverHostname, optionalCustomVerificationCallback: customVerificationCallback) + } + + // This exists to handle the explosion of initializers we got when I tried to deprecate the first one. At least they all pass through one path now. + private init(context: NIOSSLContext, serverHostname: String?, optionalCustomVerificationCallback: NIOSSLCustomVerificationCallback?) throws { + guard let connection = context.createConnection() else { + throw NIOSSLError.unableToAllocateBoringSSLObject + } + + connection.setConnectState() + if let serverHostname = serverHostname { + if serverHostname.isIPAddress() { + throw NIOSSLExtraError.cannotUseIPAddressInSNI(ipAddress: serverHostname) + } + + // IP addresses must not be provided in the SNI extension, so filter them. + try connection.setServerName(name: serverHostname) + } + + if let verificationCallback = optionalCustomVerificationCallback { + connection.setCustomVerificationCallback(CustomVerifyManager(callback: verificationCallback)) + } + + super.init(connection: connection, shutdownTimeout: context.configuration.shutdownTimeout) + } } diff --git a/Sources/NIOSSL/SSLCallbacks.swift b/Sources/NIOSSL/SSLCallbacks.swift index e283d8ee..dbc299e7 100644 --- a/Sources/NIOSSL/SSLCallbacks.swift +++ b/Sources/NIOSSL/SSLCallbacks.swift @@ -58,9 +58,34 @@ public enum NIOSSLVerificationResult { /// or to store the `NIOSSLCertificate` somewhere for later consumption. The easiest way to be sure that the /// `NIOSSLCertificate` is safe to consume is to wait for a user event that shows the handshake as completed, /// or for channelInactive. +/// +/// warning: This callback uses the old-style OpenSSL callback behaviour and is excessively complex to program with. +/// Instead, prefer using the NIOSSLCustomVerificationCallback style which receives the entire trust chain at once, +/// and also supports asynchronous certificate verification. public typealias NIOSSLVerificationCallback = (NIOSSLVerificationResult, NIOSSLCertificate) -> NIOSSLVerificationResult +/// A custom verification callback that allows completely overriding the certificate verification logic of BoringSSL. +/// +/// This verification callback is called no more than once per connection attempt. It is invoked with two arguments: +/// +/// 1. The certificate chain presented by the peer, in the order the peer presented them (with the first certificate +/// being the leaf certificate presented by the peer). +/// 2. An `EventLoopPromise` that must be completed to signal the result of the verification. +/// +/// Please be cautious with calling out from this method. This method is always invoked on the event loop, +/// so you must not block or wait. However, you may perform asynchronous work by leaving the event loop context: +/// when the verification is complete you must complete the provided `EventLoopPromise`. +/// +/// This method must take care to ensure that it does not cause any `ChannelHandler` to recursively call back into +/// the `NIOSSLHandler` that triggered it, as making re-entrant calls into BoringSSL is not supported by SwiftNIO and +/// leads to undefined behaviour. It is acceptable to leave the event loop context and then call into the `NIOSSLHandler`, +/// as this will not be re-entrant. +/// +/// Note that setting this callback will override _all_ verification logic that BoringSSL provides. +public typealias NIOSSLCustomVerificationCallback = ([NIOSSLCertificate], EventLoopPromise) -> Void + + /// A callback that can be used to implement `SSLKEYLOGFILE` support. /// /// Wireshark can decrypt packet captures that contain encrypted TLS connections if they have access to the @@ -108,3 +133,117 @@ extension KeyLogCallbackManager { self.callback(self.scratchBuffer) } } + + +/// A struct that provides helpers for working with a NIOSSLCustomVerificationCallback. +internal struct CustomVerifyManager { + private var callback: CallbackType + + private var result: PendingResult = .notStarted + + init(callback: @escaping NIOSSLCustomVerificationCallback) { + self.callback = .public(callback) + } + + init(callback: @escaping InternalCallback) { + self.callback = .internal(callback) + } +} + + +extension CustomVerifyManager { + private enum PendingResult: Hashable { + case notStarted + + case pendingResult + + case complete(NIOSSLVerificationResult) + } +} + + +extension CustomVerifyManager { + mutating func process(on connection: SSLConnection) -> ssl_verify_result_t { + // First, check if we have a result. + switch self.result { + case .complete(.certificateVerified): + return ssl_verify_ok + case .complete(.failed): + return ssl_verify_invalid + case .pendingResult: + // Ask me again. + return ssl_verify_retry + case .notStarted: + // The rest of this method handles this case. + break + } + + self.result = .pendingResult + + // Ok, no result. We need a promise for the user to use to supply a result. + guard let eventLoop = connection.eventLoop else { + // No event loop. We cannot possibly be negotiating here. + preconditionFailure("No event loop present") + } + + let promise = eventLoop.makePromise(of: NIOSSLVerificationResult.self) + + // We need to attach our "do the thing" callback. This will always invoke the "ask me again" API, and it will do so in a separate + // event loop tick to avoid awkward re-entrancy with this method. + promise.futureResult.whenComplete { result in + // When we complete here we need to set our result state, and then ask to respin certificate verification. + // If we can't respin verification because we've dropped the parent handler, that's fine, no harm no foul. + // For that reason, we tolerate both the verify manager and the parent handler being nil. + eventLoop.execute { + // Note that we don't close over self here: that's to deal with the fact that this is a struct, and we don't want to + // escape the mutable ownership of self. + precondition(connection.customVerificationManager == nil || connection.customVerificationManager?.result == .some(.pendingResult)) + connection.customVerificationManager?.result = .complete(NIOSSLVerificationResult(result)) + connection.parentHandler?.asynchronousCertificateVerificationComplete() + } + } + + // Ok, let's do it. + self.callback.invoke(on: connection, promise: promise) + return ssl_verify_retry + } +} + + +extension CustomVerifyManager { + private enum CallbackType { + case `public`(NIOSSLCustomVerificationCallback) + case `internal`(InternalCallback) + + /// For user-supplied callbacks we need to give them public types. For internal ones, we just pass the + /// `EventLoopPromise` object through. + func invoke(on connection: SSLConnection, promise: EventLoopPromise) { + switch self { + case .public(let publicCallback): + do { + let certificates = try connection.peerCertificateChain() + publicCallback(certificates, promise) + } catch { + promise.fail(error) + } + + case .internal(let internalCallback): + internalCallback(promise) + } + } + } + + internal typealias InternalCallback = (EventLoopPromise) -> Void +} + + +extension NIOSSLVerificationResult { + init(_ result: Result) { + switch result { + case .success(let s): + self = s + case .failure: + self = .failed + } + } +} diff --git a/Sources/NIOSSL/SSLCertificate.swift b/Sources/NIOSSL/SSLCertificate.swift index dc033c54..813959a3 100644 --- a/Sources/NIOSSL/SSLCertificate.swift +++ b/Sources/NIOSSL/SSLCertificate.swift @@ -47,7 +47,7 @@ public class NIOSSLCertificate { case ipv6(in6_addr) } - private init(withReference ref: UnsafeMutablePointer) { + private init(withOwnedReference ref: UnsafeMutablePointer) { self._ref = UnsafeMutableRawPointer(ref) // erasing the type for @_implementationOnly import CNIOBoringSSL } @@ -73,7 +73,7 @@ public class NIOSSLCertificate { throw NIOSSLError.failedToLoadCertificate } - self.init(withReference: x509!) + self.init(withOwnedReference: x509!) } /// Create a NIOSSLCertificate from a buffer of bytes in either PEM or @@ -107,7 +107,36 @@ public class NIOSSLCertificate { throw NIOSSLError.failedToLoadCertificate } - self.init(withReference: ref!) + self.init(withOwnedReference: ref!) + } + + /// Create a NIOSSLCertificate from a buffer of bytes in either PEM or DER format. + internal convenience init(bytes ptr: UnsafeRawBufferPointer, format: NIOSSLSerializationFormats) throws { + // TODO(cory): + // The body of this method is exactly identical to the initializer above, except for the "withUnsafeBytes" call. + // ContiguousBytes would have been the lowest effort way to reduce this duplication, but we can't use it without + // bringing Foundation in. Probably we should use Sequence where Element == UInt8 and the withUnsafeContiguousBytesIfAvailable + // method, but that's a much more substantial refactor. Let's do it later. + let bio = CNIOBoringSSL_BIO_new_mem_buf(ptr.baseAddress, CInt(ptr.count))! + + defer { + CNIOBoringSSL_BIO_free(bio) + } + + let ref: UnsafeMutablePointer? + + switch format { + case .pem: + ref = CNIOBoringSSL_PEM_read_bio_X509(bio, nil, nil, nil) + case .der: + ref = CNIOBoringSSL_d2i_X509_bio(bio, nil) + } + + if ref == nil { + throw NIOSSLError.failedToLoadCertificate + } + + self.init(withOwnedReference: ref!) } /// Create a NIOSSLCertificate wrapping a pointer into BoringSSL. @@ -121,7 +150,7 @@ public class NIOSSLCertificate { /// In general, however, this function should be avoided in favour of one of the convenience /// initializers, which ensure that the lifetime of the `X509` object is better-managed. static func fromUnsafePointer(takingOwnership pointer: UnsafeMutablePointer) -> NIOSSLCertificate { - return NIOSSLCertificate(withReference: pointer) + return NIOSSLCertificate(withOwnedReference: pointer) } /// Get a sequence of the alternative names in the certificate. @@ -267,10 +296,10 @@ extension NIOSSLCertificate { throw NIOSSLError.failedToLoadCertificate } - var certificates = [NIOSSLCertificate(withReference: x509)] + var certificates = [NIOSSLCertificate(withOwnedReference: x509)] while let x = CNIOBoringSSL_PEM_read_bio_X509(bio, nil, nil, nil) { - certificates.append(.init(withReference: x)) + certificates.append(.init(withOwnedReference: x)) } let err = CNIOBoringSSL_ERR_peek_error() diff --git a/Sources/NIOSSL/SSLConnection.swift b/Sources/NIOSSL/SSLConnection.swift index 1429573b..ba42a13b 100644 --- a/Sources/NIOSSL/SSLConnection.swift +++ b/Sources/NIOSSL/SSLConnection.swift @@ -48,13 +48,15 @@ internal final class SSLConnection { private let ssl: OpaquePointer private let parentContext: NIOSSLContext private var bio: ByteBufferBIO? - private var verificationCallback: NIOSSLVerificationCallback? - internal var platformVerificationState: PlatformVerificationState = PlatformVerificationState() internal var expectedHostname: String? internal var role: ConnectionRole? internal var parentHandler: NIOSSLHandler? internal var eventLoop: EventLoop? + /// Deprecated in favour of customVerificationManager + private var verificationCallback: NIOSSLVerificationCallback? + internal var customVerificationManager: CustomVerifyManager? + /// Whether certificate hostnames should be validated. var validateHostnames: Bool { if case .fullVerification = parentContext.configuration.certificateVerification { @@ -116,7 +118,9 @@ internal final class SSLConnection { self.expectedHostname = name } - /// Sets the BoringSSL verification callback. + /// Sets the BoringSSL old-style verification callback. + /// + /// This is deprecated in favour of the new-style verification callback in SSLContext. func setVerificationCallback(_ callback: @escaping NIOSSLVerificationCallback) { // Store the verification callback. We need to do this to keep it alive throughout the connection. // We'll drop this when we're told that it's no longer needed to ensure we break the reference cycles @@ -156,6 +160,33 @@ internal final class SSLConnection { } } + func setCustomVerificationCallback(_ callbackManager: CustomVerifyManager) { + // Store the verification callback. We need to do this to keep it alive throughout the connection. + // We'll drop this when we're told that it's no longer needed to ensure we break the reference cycles + // that this callback inevitably produces. + assert(self.customVerificationManager == nil) + self.customVerificationManager = callbackManager + + // We need to know what the current mode is. + let currentMode = CNIOBoringSSL_SSL_get_verify_mode(self.ssl) + CNIOBoringSSL_SSL_set_custom_verify(self.ssl, currentMode) { ssl, outAlert in + guard let unwrappedSSL = ssl else { + preconditionFailure("Unexpected null pointer in custom verification callback. ssl: \(String(describing: ssl))") + } + + // Ok, this call may be a resumption of a previous negotiation. We need to check if our connection object has a pre-existing verifiation state. + guard let connectionPointer = CNIOBoringSSL_SSL_get_ex_data(unwrappedSSL, sslConnectionExDataIndex) else { + // Uh-ok, our application state is gone. Don't let this error silently pass, go bang. + preconditionFailure("Unable to find application data from SSL * \(unwrappedSSL), index \(sslConnectionExDataIndex)") + } + + let connection = Unmanaged.fromOpaque(connectionPointer).takeUnretainedValue() + + // We force unwrap the custom verification manager because for it to not be set is a programmer error. + return connection.customVerificationManager!.process(on: connection) + } + } + /// Sets whether renegotiation is supported. func setRenegotiationSupport(_ state: NIORenegotiationSupport) { var baseState: ssl_renegotiate_mode_t @@ -364,9 +395,10 @@ internal final class SSLConnection { /// Must only be called when the connection is no longer needed. The rest of this object /// preconditions on that being true, so we'll find out quickly when that's not the case. func close() { - /// Drop the verification callback. This breaks any reference cycles that are inevitably - /// created by this callback. + /// Drop the verification callbacks. This breaks any reference cycles that are inevitably + /// created by these callbacks. self.verificationCallback = nil + self.customVerificationManager = nil // Also drop the reference to the parent channel handler, which is a trivial reference cycle. self.parentHandler = nil @@ -422,6 +454,17 @@ extension SSLConnection { return try body(PeerCertificateChainBuffers(basePointer: stackPointer)) } + + /// The certificate chain presented by the peer. + func peerCertificateChain() throws -> [NIOSSLCertificate] { + return try self.withPeerCertificateChainBuffers { buffers in + guard let buffers = buffers else { + return [] + } + + return try buffers.map { try NIOSSLCertificate(bytes: $0, format: .der) } + } + } } extension SSLConnection.PeerCertificateChainBuffers: RandomAccessCollection { diff --git a/Sources/NIOSSL/SSLContext.swift b/Sources/NIOSSL/SSLContext.swift index bcdcc773..8c32d2b0 100644 --- a/Sources/NIOSSL/SSLContext.swift +++ b/Sources/NIOSSL/SSLContext.swift @@ -246,7 +246,20 @@ public final class NIOSSLContext { guard let ssl = CNIOBoringSSL_SSL_new(self.sslContext) else { return nil } - return SSLConnection(ownedSSL: ssl, parentContext: self) + + let conn = SSLConnection(ownedSSL: ssl, parentContext: self) + + // If we need to turn on the validation on Apple platforms, do it here. + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + switch self.configuration.trustRoots { + case .some(.default), .none: + conn.setCustomVerificationCallback(CustomVerifyManager(callback: conn.performSecurityFrameworkValidation(promise:))) + case .some(.certificates), .some(.file): + break + } + #endif + + return conn } fileprivate func alpnSelectCallback(offeredProtocols: UnsafeBufferPointer) -> (index: Int, length: Int)? { @@ -397,8 +410,9 @@ extension NIOSSLContext { } private static func platformDefaultConfiguration(context: OpaquePointer) throws { - // Platform default trust is configured differently in different places. On Darwin we invoke Security.framework in a custom callback. + // Platform default trust is configured differently in different places. // On Linux, we use our searched heuristics to guess about where the platform trust store is. + // On Darwin, we use a custom callback that is set later, in createConnection #if os(Linux) let result = rootCAFilePath.withCString { rootCAFilePointer in rootCADirectoryPath.withCString { rootCADirectoryPointer in @@ -410,8 +424,6 @@ extension NIOSSLContext { let errorStack = BoringSSLError.buildErrorStack() throw BoringSSLError.unknownError(errorStack) } - #elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS) - CNIOBoringSSL_SSL_CTX_set_custom_verify(context, SSL_VERIFY_PEER, securityFrameworkCustomVerify) #endif } diff --git a/Sources/NIOSSL/SSLErrors.swift b/Sources/NIOSSL/SSLErrors.swift index 3e4474ff..c1664a05 100644 --- a/Sources/NIOSSL/SSLErrors.swift +++ b/Sources/NIOSSL/SSLErrors.swift @@ -210,6 +210,7 @@ extension NIOSSLExtraError { private enum BaseError: Equatable { case failedToValidateHostname case serverHostnameImpossibleToMatch + case cannotUseIPAddressInSNI } } @@ -221,15 +222,26 @@ extension NIOSSLExtraError { /// The server hostname provided by the user cannot match any names in the certificate due to containing invalid characters. public static let serverHostnameImpossibleToMatch = NIOSSLExtraError(baseError: .serverHostnameImpossibleToMatch, description: nil) + /// IP addresses may not be used in SNI. + public static let cannotUseIPAddressInSNI = NIOSSLExtraError(baseError: .cannotUseIPAddressInSNI, description: nil) + + @inline(never) internal static func failedToValidateHostname(expectedName: String) -> NIOSSLExtraError { let description = "Couldn't find \(expectedName) in certificate from peer" return NIOSSLExtraError(baseError: .failedToValidateHostname, description: description) } + @inline(never) internal static func serverHostnameImpossibleToMatch(hostname: String) -> NIOSSLExtraError { let description = "The server hostname \(hostname) cannot be matched due to containing non-DNS characters" return NIOSSLExtraError(baseError: .serverHostnameImpossibleToMatch, description: description) } + + @inline(never) + internal static func cannotUseIPAddressInSNI(ipAddress: String) -> NIOSSLExtraError { + let description = "IP addresses cannot validly be used for Server Name Indication, got \(ipAddress)" + return NIOSSLExtraError(baseError: .cannotUseIPAddressInSNI, description: description) + } } diff --git a/Sources/NIOSSL/SecurityFrameworkCertificateVerification.swift b/Sources/NIOSSL/SecurityFrameworkCertificateVerification.swift index 0069207e..658fec87 100644 --- a/Sources/NIOSSL/SecurityFrameworkCertificateVerification.swift +++ b/Sources/NIOSSL/SecurityFrameworkCertificateVerification.swift @@ -11,6 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// +import NIO #if compiler(>=5.1) && compiler(<5.2) @_implementationOnly import CNIOBoringSSL @@ -18,126 +19,59 @@ import CNIOBoringSSL #endif -/// The current state of the platform verification helper, if one is in use. -/// -/// Only used on Apple platforms currently. -internal struct PlatformVerificationState { - #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) - fileprivate var state: SecurityFrameworkVerificationState? = nil - #endif -} - // We can only use Security.framework to validate TLS certificates on Apple platforms. #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) import Dispatch import Foundation import Security -/// A custom certificate verification function for BoringSSL that uses Security.framework to provide certificate verification. -/// -/// - parameters: -/// - ssl: The pointer to the SSL * object for this connection. -/// - outAlert: A C-style inout parameter that contains a pointer to an alert. Is assumed to be non-null. -internal func securityFrameworkCustomVerify(_ ssl: OpaquePointer?, _ outAlert: UnsafeMutablePointer?) -> ssl_verify_result_t { - guard let unwrappedSSL = ssl, let unwrappedOutAlert = outAlert else { - preconditionFailure("Unexpected null pointer in custom verification callback. ssl: \(String(describing: ssl)) outAlert: \(String(describing: outAlert))") - } - - // Ok, this call may be a resumption of a previous negotiation. We need to check if our connection object has a pre-existing verifiation state. - guard let connectionPointer = CNIOBoringSSL_SSL_get_ex_data(unwrappedSSL, sslConnectionExDataIndex) else { - // Uh-ok, our application state is gone. Don't let this error silently pass, go bang. - preconditionFailure("Unable to find application data from SSL * \(unwrappedSSL), index \(sslConnectionExDataIndex)") - } - - let connection = Unmanaged.fromOpaque(connectionPointer).takeUnretainedValue() - - do { - return try connection.performSecurityFrameworkValidation() - } catch { - unwrappedOutAlert.pointee = UInt8(SSL_AD_INTERNAL_ERROR) - return ssl_verify_invalid - } -} - -extension PlatformVerificationState { - fileprivate enum SecurityFrameworkVerificationState { - case pendingResult - - case complete(SecTrustResultType) - } -} - extension SSLConnection { - func performSecurityFrameworkValidation() throws -> ssl_verify_result_t { - // First, check whether we have an outstanding or completed query. If we do, don't do any other work. - switch self.platformVerificationState.state { - case .some(.complete(.proceed)), .some(.complete(.unspecified)): - // These two cases mean we have successfully validated the certificate. We're done! Wipe out the state so - // that if we need to reverify we can, and return the success. - self.platformVerificationState.state = nil - return ssl_verify_ok - case .some(.complete): - // Ok, this broader case means we failed. We're still done, but return the failure instead. - self.platformVerificationState.state = nil - return ssl_verify_invalid - case .some(.pendingResult): - // We've got a validation attempt outstanding. Tell BoringSSL to hold its horses. - return ssl_verify_retry - case .none: - // No verification outstanding: do more work. - break - } + func performSecurityFrameworkValidation(promise: EventLoopPromise) { + do { + // Ok, time to kick off a validation. Let's get some certificate buffers. + let certificates: [SecCertificate] = try self.withPeerCertificateChainBuffers { buffers in + guard let buffers = buffers else { + throw NIOSSLError.unableToValidateCertificate + } + + return try buffers.map { buffer in + let data = Data(bytes: buffer.baseAddress!, count: buffer.count) + guard let cert = SecCertificateCreateWithData(nil, data as CFData) else { + throw NIOSSLError.unableToValidateCertificate + } + return cert + } + } - // Ok, time to kick off a validation. Let's get some certificate buffers. - let certificates: [SecCertificate] = try self.withPeerCertificateChainBuffers { buffers in - guard let buffers = buffers else { + // This force-unwrap is safe as we must have decided if we're a client or a server before validation. + var trust: SecTrust? = nil + var result: OSStatus + let policy = SecPolicyCreateSSL(self.role! == .client, self.expectedHostname as CFString?) + result = SecTrustCreateWithCertificates(certificates as CFArray, policy, &trust) + guard result == errSecSuccess, let actualTrust = trust else { throw NIOSSLError.unableToValidateCertificate } - return try buffers.map { buffer in - let data = Data(bytes: buffer.baseAddress!, count: buffer.count) - guard let cert = SecCertificateCreateWithData(nil, data as CFData) else { - throw NIOSSLError.unableToValidateCertificate + // We create a DispatchQueue here to be called back on, as this validation may perform network activity. + let callbackQueue = DispatchQueue(label: "io.swiftnio.ssl.validationCallbackQueue") + + result = SecTrustEvaluateAsync(actualTrust, callbackQueue) { (_, result) in + switch result { + case .proceed, .unspecified: + // These two cases mean we have successfully validated the certificate. We're done! + promise.succeed(.certificateVerified) + default: + // Oops, we failed. + promise.succeed(.failed) } - return cert } - } - // This force-unwrap is safe as we must have decided if we're a client or a server before validation. - var trust: SecTrust? = nil - var result: OSStatus - let policy = SecPolicyCreateSSL(self.role! == .client, self.expectedHostname as CFString?) - result = SecTrustCreateWithCertificates(certificates as CFArray, policy, &trust) - guard result == errSecSuccess, let actualTrust = trust else { - throw NIOSSLError.unableToValidateCertificate - } - - // We create a DispatchQueue here to be called back on, as this validation may perform network activity. - let callbackQueue = DispatchQueue(label: "io.swiftnio.ssl.validationCallbackQueue") - - // Now we need to grab some things we need in the callback block. Specifically, we need the parent handler - // and the event loop it belongs to. This is because we cannot safely access these things from inside the - // block, and we need the eventLoop to get back onto a safe thread. - // - // We don't hold these references weak because we are ok with keeping the handler alive longer than necessary - // in the rare case that the handler is removed before the callback completes. - // The force-unwrap is safe, as we cannot be midway through handshaking before the connection has become active. - let eventLoop = self.eventLoop! - - result = SecTrustEvaluateAsync(actualTrust, callbackQueue) { (_, result) in - // When we complete here we need to set our result state, and then ask to respin certificate verification. - // If we can't respin verification because we've dropped the parent handler, that's fine, no harm no foul. - eventLoop.execute { - self.platformVerificationState.state = .complete(result) - self.parentHandler?.asynchronousCertificateVerificationComplete() + guard result == errSecSuccess else { + throw NIOSSLError.unableToValidateCertificate } + } catch { + promise.fail(error) } - - guard result == errSecSuccess else { - throw NIOSSLError.unableToValidateCertificate - } - - return ssl_verify_retry } } diff --git a/Tests/NIOSSLTests/ClientSNITests.swift b/Tests/NIOSSLTests/ClientSNITests.swift index ca743a4b..a1e995b7 100644 --- a/Tests/NIOSSLTests/ClientSNITests.swift +++ b/Tests/NIOSSLTests/ClientSNITests.swift @@ -83,7 +83,7 @@ class ClientSNITests: XCTestCase { do { _ = try NIOSSLClientHandler(context: context, serverHostname: "192.168.0.1") XCTFail("Created client handler with invalid SNI name") - } catch BoringSSLError.invalidSNIName { + } catch let err as NIOSSLExtraError where err == NIOSSLExtraError.cannotUseIPAddressInSNI { // All fine. } } @@ -94,7 +94,7 @@ class ClientSNITests: XCTestCase { do { _ = try NIOSSLClientHandler(context: context, serverHostname: "fe80::200:f8ff:fe21:67cf") XCTFail("Created client handler with invalid SNI name") - } catch BoringSSLError.invalidSNIName { + } catch let err as NIOSSLExtraError where err == NIOSSLExtraError.cannotUseIPAddressInSNI { // All fine. } } diff --git a/Tests/NIOSSLTests/NIOSSLIntegrationTest+XCTest.swift b/Tests/NIOSSLTests/NIOSSLIntegrationTest+XCTest.swift index 046dfee8..87765dda 100644 --- a/Tests/NIOSSLTests/NIOSSLIntegrationTest+XCTest.swift +++ b/Tests/NIOSSLTests/NIOSSLIntegrationTest+XCTest.swift @@ -47,6 +47,10 @@ extension NIOSSLIntegrationTest { ("testFlushPendingReadsOnCloseNotify", testFlushPendingReadsOnCloseNotify), ("testForcingVerificationFailure", testForcingVerificationFailure), ("testExtractingCertificates", testExtractingCertificates), + ("testForcingVerificationFailureNewCallback", testForcingVerificationFailureNewCallback), + ("testErroringNewVerificationCallback", testErroringNewVerificationCallback), + ("testNewCallbackCanDelayHandshake", testNewCallbackCanDelayHandshake), + ("testExtractingCertificatesNewCallback", testExtractingCertificatesNewCallback), ("testRepeatedClosure", testRepeatedClosure), ("testClosureTimeout", testClosureTimeout), ("testReceivingGibberishAfterAttemptingToClose", testReceivingGibberishAfterAttemptingToClose), diff --git a/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift b/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift index 1ff76f11..1c74127c 100644 --- a/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift +++ b/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift @@ -18,6 +18,7 @@ import XCTest #else import CNIOBoringSSL #endif +import NIOConcurrencyHelpers import NIO @testable import NIOSSL import NIOTLS @@ -292,21 +293,70 @@ internal func clientTLSChannel(context: NIOSSLContext, group: EventLoopGroup, connectingTo: SocketAddress, serverHostname: String? = nil, - verificationCallback: NIOSSLVerificationCallback? = nil, file: StaticString = #file, line: UInt = #line) throws -> Channel { + func handlerFactory() throws -> NIOSSLClientHandler { + return try NIOSSLClientHandler(context: context, serverHostname: serverHostname) + } + + return try _clientTLSChannel(context: context, preHandlers: preHandlers, postHandlers: postHandlers, group: group, connectingTo: connectingTo, handlerFactory: handlerFactory) +} + + +@available(*, deprecated, renamed: "clientTLSChannel(context:preHandlers:postHandlers:group:connectingTo:serverHostname:customVerificationCallback:file:line:)") +internal func clientTLSChannel(context: NIOSSLContext, + preHandlers: [ChannelHandler], + postHandlers: [ChannelHandler], + group: EventLoopGroup, + connectingTo: SocketAddress, + serverHostname: String? = nil, + verificationCallback: @escaping NIOSSLVerificationCallback, + file: StaticString = #file, + line: UInt = #line) throws -> Channel { + func handlerFactory() throws -> NIOSSLClientHandler { + return try NIOSSLClientHandler(context: context, serverHostname: serverHostname, verificationCallback: verificationCallback) + } + + return try _clientTLSChannel(context: context, preHandlers: preHandlers, postHandlers: postHandlers, group: group, connectingTo: connectingTo, handlerFactory: handlerFactory) +} + + +internal func clientTLSChannel(context: NIOSSLContext, + preHandlers: [ChannelHandler], + postHandlers: [ChannelHandler], + group: EventLoopGroup, + connectingTo: SocketAddress, + serverHostname: String? = nil, + customVerificationCallback: @escaping NIOSSLCustomVerificationCallback, + file: StaticString = #file, + line: UInt = #line) throws -> Channel { + func handlerFactory() throws -> NIOSSLClientHandler { + return try NIOSSLClientHandler(context: context, serverHostname: serverHostname, customVerificationCallback: customVerificationCallback) + } + + return try _clientTLSChannel(context: context, preHandlers: preHandlers, postHandlers: postHandlers, group: group, connectingTo: connectingTo, handlerFactory: handlerFactory) +} + +fileprivate func _clientTLSChannel(context: NIOSSLContext, + preHandlers: [ChannelHandler], + postHandlers: [ChannelHandler], + group: EventLoopGroup, + connectingTo: SocketAddress, + handlerFactory: @escaping () throws -> NIOSSLClientHandler, + file: StaticString = #file, + line: UInt = #line) throws -> Channel { return try assertNoThrowWithValue(ClientBootstrap(group: group) .channelInitializer { channel in let results = preHandlers.map { channel.pipeline.addHandler($0) } return EventLoopFuture.andAllSucceed(results, on: results.first?.eventLoop ?? group.next()).flatMapThrowing { - try NIOSSLClientHandler(context: context, serverHostname: serverHostname, verificationCallback: verificationCallback) - }.flatMap { - channel.pipeline.addHandler($0) - }.flatMap { - let results = postHandlers.map { channel.pipeline.addHandler($0) } - return EventLoopFuture.andAllSucceed(results, on: results.first?.eventLoop ?? group.next()) - } - }.connect(to: connectingTo).wait(), file: file, line: line) + try handlerFactory() + }.flatMap { + channel.pipeline.addHandler($0) + }.flatMap { + let results = postHandlers.map { channel.pipeline.addHandler($0) } + return EventLoopFuture.andAllSucceed(results, on: results.first?.eventLoop ?? group.next()) + } + }.connect(to: connectingTo).wait(), file: file, line: line) } class NIOSSLIntegrationTest: XCTestCase { @@ -1161,6 +1211,7 @@ class NIOSSLIntegrationTest: XCTestCase { XCTAssertEqual(readData.readString(length: readData.readableBytes)!, "Hello") } + @available(*, deprecated, message: "Testing deprecated API surface") func testForcingVerificationFailure() throws { let context = try configuredSSLContext() @@ -1209,6 +1260,7 @@ class NIOSSLIntegrationTest: XCTestCase { } } + @available(*, deprecated, message: "Testing deprecated API surface") func testExtractingCertificates() throws { let context = try configuredSSLContext() @@ -1249,6 +1301,206 @@ class NIOSSLIntegrationTest: XCTestCase { XCTAssertEqual(certificates.count, 1) } + func testForcingVerificationFailureNewCallback() throws { + let context = try configuredSSLContext() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverChannel: Channel = try serverTLSChannel(context: context, handlers: [], group: group) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + let errorHandler = ErrorCatcher() + let clientChannel = try clientTLSChannel(context: try configuredClientContext(), + preHandlers: [], + postHandlers: [errorHandler], + group: group, + connectingTo: serverChannel.localAddress!, + customVerificationCallback: { _, promise in + promise.succeed(.failed) + }) + + var originalBuffer = clientChannel.allocator.buffer(capacity: 5) + originalBuffer.writeString("Hello") + let writeFuture = clientChannel.writeAndFlush(originalBuffer) + let errorsFuture: EventLoopFuture<[NIOSSLError]> = writeFuture.recover { (_: Error) in + // We're swallowing errors here, on purpose, because we'll definitely + // hit them. + return () + }.map { + return errorHandler.errors + } + let actualErrors = try errorsFuture.wait() + + // This write will have failed, but that's fine: we just want it as a signal that + // the handshake is done so we can make our assertions. + XCTAssertEqual(actualErrors.count, 1) + switch actualErrors.first! { + case .handshakeFailed: + // expected + break + case let error: + XCTFail("Unexpected error: \(error)") + } + } + + func testErroringNewVerificationCallback() throws { + enum LocalError: Error { + case kaboom + } + + let context = try configuredSSLContext() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverChannel: Channel = try serverTLSChannel(context: context, handlers: [], group: group) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + let errorHandler = ErrorCatcher() + let clientChannel = try clientTLSChannel(context: try configuredClientContext(), + preHandlers: [], + postHandlers: [errorHandler], + group: group, + connectingTo: serverChannel.localAddress!, + customVerificationCallback: { _, promise in + promise.fail(LocalError.kaboom) + }) + + var originalBuffer = clientChannel.allocator.buffer(capacity: 5) + originalBuffer.writeString("Hello") + let writeFuture = clientChannel.writeAndFlush(originalBuffer) + let errorsFuture: EventLoopFuture<[NIOSSLError]> = writeFuture.recover { (_: Error) in + // We're swallowing errors here, on purpose, because we'll definitely + // hit them. + return () + }.map { + return errorHandler.errors + } + let actualErrors = try errorsFuture.wait() + + // This write will have failed, but that's fine: we just want it as a signal that + // the handshake is done so we can make our assertions. + XCTAssertEqual(actualErrors.count, 1) + switch actualErrors.first! { + case .handshakeFailed: + // expected + break + case let error: + XCTFail("Unexpected error: \(error)") + } + } + + func testNewCallbackCanDelayHandshake() throws { + let context = try configuredSSLContext() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + + var completionPromiseFired: Bool = false + let completionPromiseFiredLock = Lock() + + let completionPromise: EventLoopPromise = group.next().makePromise() + completionPromise.futureResult.whenComplete { _ in + completionPromiseFiredLock.withLock { + completionPromiseFired = true + } + } + + + let serverChannel: Channel = try serverTLSChannel(context: context, handlers: [SimpleEchoServer()], group: group) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + var handshakeCompletePromise: EventLoopPromise? = nil + let handshakeFiredPromise: EventLoopPromise = group.next().makePromise() + + let clientChannel = try clientTLSChannel(context: configuredClientContext(), + preHandlers: [], + postHandlers: [PromiseOnReadHandler(promise: completionPromise)], + group: group, + connectingTo: serverChannel.localAddress!, + serverHostname: "localhost", + customVerificationCallback: { innerCertificates, promise in + handshakeCompletePromise = promise + handshakeFiredPromise.succeed(()) + }) + defer { + XCTAssertNoThrow(try clientChannel.close().wait()) + } + + var originalBuffer = clientChannel.allocator.buffer(capacity: 5) + originalBuffer.writeString("Hello") + clientChannel.writeAndFlush(originalBuffer, promise: nil) + + // This has driven the handshake to begin, so we can wait for that. + XCTAssertNoThrow(try handshakeFiredPromise.futureResult.wait()) + + // We can now check whether the completion promise has fired: it should not have. + completionPromiseFiredLock.withLock { + XCTAssertFalse(completionPromiseFired) + } + + // Ok, allow the handshake to run. + handshakeCompletePromise!.succeed(.certificateVerified) + + let newBuffer = try completionPromise.futureResult.wait() + XCTAssertTrue(completionPromiseFired) + XCTAssertEqual(newBuffer, originalBuffer) + } + + func testExtractingCertificatesNewCallback() throws { + let context = try configuredSSLContext() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completionPromise: EventLoopPromise = group.next().makePromise() + + let serverChannel: Channel = try serverTLSChannel(context: context, handlers: [SimpleEchoServer()], group: group) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + var certificates = [NIOSSLCertificate]() + let clientChannel = try clientTLSChannel(context: configuredClientContext(), + preHandlers: [], + postHandlers: [PromiseOnReadHandler(promise: completionPromise)], + group: group, + connectingTo: serverChannel.localAddress!, + serverHostname: "localhost", + customVerificationCallback: { innerCertificates, promise in + certificates = innerCertificates + promise.succeed(.certificateVerified) + }) + defer { + XCTAssertNoThrow(try clientChannel.close().wait()) + } + + var originalBuffer = clientChannel.allocator.buffer(capacity: 5) + originalBuffer.writeString("Hello") + XCTAssertNoThrow(try clientChannel.writeAndFlush(originalBuffer).wait()) + + let newBuffer = try completionPromise.futureResult.wait() + XCTAssertEqual(newBuffer, originalBuffer) + + XCTAssertEqual(certificates, [NIOSSLIntegrationTest.cert]) + } + func testRepeatedClosure() throws { let serverChannel = EmbeddedChannel() let clientChannel = EmbeddedChannel()