From 75f7b54a5fb6d4fb6203d9ccdd6b8818684e0311 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Thu, 16 Jan 2020 15:52:35 +0000 Subject: [PATCH] Provide better certificate verification callback (#171) Motivation: We've exposed the OpenSSL certificate verify callback for a while now. Unfortunately, this callback is fairly hostile to callers, as it is called repeatedly through the certificate verification process, instead of presenting the unified certificate chain to the user. BoringSSL has a better verification callback available that we use internally, but that is not exposed to users. It would be nice to give users the better API. Modifications: - Expose the BoringSSL verification callback to users. - Re-plumb the Security.framework verification to use our public API. Result: Easier certificate chain verification overrides. --- Sources/NIOSSL/NIOSSLClientHandler.swift | 35 ++- Sources/NIOSSL/SSLCallbacks.swift | 139 +++++++++ Sources/NIOSSL/SSLCertificate.swift | 41 ++- Sources/NIOSSL/SSLConnection.swift | 53 +++- Sources/NIOSSL/SSLContext.swift | 20 +- Sources/NIOSSL/SSLErrors.swift | 12 + ...rityFrameworkCertificateVerification.swift | 142 +++------ Tests/NIOSSLTests/ClientSNITests.swift | 4 +- .../NIOSSLIntegrationTest+XCTest.swift | 4 + Tests/NIOSSLTests/NIOSSLIntegrationTest.swift | 270 +++++++++++++++++- 10 files changed, 589 insertions(+), 131 deletions(-) diff --git a/Sources/NIOSSL/NIOSSLClientHandler.swift b/Sources/NIOSSL/NIOSSLClientHandler.swift index 24f19291..57fde8e4 100644 --- a/Sources/NIOSSL/NIOSSLClientHandler.swift +++ b/Sources/NIOSSL/NIOSSLClientHandler.swift @@ -31,6 +31,11 @@ private extension String { /// This handler can be used in channels that are acting as the client /// in the TLS dialog. For server connections, use the `NIOSSLServerHandler`. public final class NIOSSLClientHandler: NIOSSLHandler { + public convenience init(context: NIOSSLContext, serverHostname: String?) throws { + try self.init(context: context, serverHostname: serverHostname, optionalCustomVerificationCallback: nil) + } + + @available(*, deprecated, renamed: "init(context:serverHostname:customVerificationCallback:)") public init(context: NIOSSLContext, serverHostname: String?, verificationCallback: NIOSSLVerificationCallback? = nil) throws { guard let connection = context.createConnection() else { throw NIOSSLError.unableToAllocateBoringSSLObject @@ -39,7 +44,7 @@ public final class NIOSSLClientHandler: NIOSSLHandler { connection.setConnectState() if let serverHostname = serverHostname { if serverHostname.isIPAddress() { - throw BoringSSLError.invalidSNIName([]) + throw NIOSSLExtraError.cannotUseIPAddressInSNI(ipAddress: serverHostname) } // IP addresses must not be provided in the SNI extension, so filter them. @@ -52,4 +57,32 @@ public final class NIOSSLClientHandler: NIOSSLHandler { super.init(connection: connection, shutdownTimeout: context.configuration.shutdownTimeout) } + + + public convenience init(context: NIOSSLContext, serverHostname: String?, customVerificationCallback: @escaping NIOSSLCustomVerificationCallback) throws { + try self.init(context: context, serverHostname: serverHostname, optionalCustomVerificationCallback: customVerificationCallback) + } + + // This exists to handle the explosion of initializers we got when I tried to deprecate the first one. At least they all pass through one path now. + private init(context: NIOSSLContext, serverHostname: String?, optionalCustomVerificationCallback: NIOSSLCustomVerificationCallback?) throws { + guard let connection = context.createConnection() else { + throw NIOSSLError.unableToAllocateBoringSSLObject + } + + connection.setConnectState() + if let serverHostname = serverHostname { + if serverHostname.isIPAddress() { + throw NIOSSLExtraError.cannotUseIPAddressInSNI(ipAddress: serverHostname) + } + + // IP addresses must not be provided in the SNI extension, so filter them. + try connection.setServerName(name: serverHostname) + } + + if let verificationCallback = optionalCustomVerificationCallback { + connection.setCustomVerificationCallback(CustomVerifyManager(callback: verificationCallback)) + } + + super.init(connection: connection, shutdownTimeout: context.configuration.shutdownTimeout) + } } diff --git a/Sources/NIOSSL/SSLCallbacks.swift b/Sources/NIOSSL/SSLCallbacks.swift index e283d8ee..dbc299e7 100644 --- a/Sources/NIOSSL/SSLCallbacks.swift +++ b/Sources/NIOSSL/SSLCallbacks.swift @@ -58,9 +58,34 @@ public enum NIOSSLVerificationResult { /// or to store the `NIOSSLCertificate` somewhere for later consumption. The easiest way to be sure that the /// `NIOSSLCertificate` is safe to consume is to wait for a user event that shows the handshake as completed, /// or for channelInactive. +/// +/// warning: This callback uses the old-style OpenSSL callback behaviour and is excessively complex to program with. +/// Instead, prefer using the NIOSSLCustomVerificationCallback style which receives the entire trust chain at once, +/// and also supports asynchronous certificate verification. public typealias NIOSSLVerificationCallback = (NIOSSLVerificationResult, NIOSSLCertificate) -> NIOSSLVerificationResult +/// A custom verification callback that allows completely overriding the certificate verification logic of BoringSSL. +/// +/// This verification callback is called no more than once per connection attempt. It is invoked with two arguments: +/// +/// 1. The certificate chain presented by the peer, in the order the peer presented them (with the first certificate +/// being the leaf certificate presented by the peer). +/// 2. An `EventLoopPromise` that must be completed to signal the result of the verification. +/// +/// Please be cautious with calling out from this method. This method is always invoked on the event loop, +/// so you must not block or wait. However, you may perform asynchronous work by leaving the event loop context: +/// when the verification is complete you must complete the provided `EventLoopPromise`. +/// +/// This method must take care to ensure that it does not cause any `ChannelHandler` to recursively call back into +/// the `NIOSSLHandler` that triggered it, as making re-entrant calls into BoringSSL is not supported by SwiftNIO and +/// leads to undefined behaviour. It is acceptable to leave the event loop context and then call into the `NIOSSLHandler`, +/// as this will not be re-entrant. +/// +/// Note that setting this callback will override _all_ verification logic that BoringSSL provides. +public typealias NIOSSLCustomVerificationCallback = ([NIOSSLCertificate], EventLoopPromise) -> Void + + /// A callback that can be used to implement `SSLKEYLOGFILE` support. /// /// Wireshark can decrypt packet captures that contain encrypted TLS connections if they have access to the @@ -108,3 +133,117 @@ extension KeyLogCallbackManager { self.callback(self.scratchBuffer) } } + + +/// A struct that provides helpers for working with a NIOSSLCustomVerificationCallback. +internal struct CustomVerifyManager { + private var callback: CallbackType + + private var result: PendingResult = .notStarted + + init(callback: @escaping NIOSSLCustomVerificationCallback) { + self.callback = .public(callback) + } + + init(callback: @escaping InternalCallback) { + self.callback = .internal(callback) + } +} + + +extension CustomVerifyManager { + private enum PendingResult: Hashable { + case notStarted + + case pendingResult + + case complete(NIOSSLVerificationResult) + } +} + + +extension CustomVerifyManager { + mutating func process(on connection: SSLConnection) -> ssl_verify_result_t { + // First, check if we have a result. + switch self.result { + case .complete(.certificateVerified): + return ssl_verify_ok + case .complete(.failed): + return ssl_verify_invalid + case .pendingResult: + // Ask me again. + return ssl_verify_retry + case .notStarted: + // The rest of this method handles this case. + break + } + + self.result = .pendingResult + + // Ok, no result. We need a promise for the user to use to supply a result. + guard let eventLoop = connection.eventLoop else { + // No event loop. We cannot possibly be negotiating here. + preconditionFailure("No event loop present") + } + + let promise = eventLoop.makePromise(of: NIOSSLVerificationResult.self) + + // We need to attach our "do the thing" callback. This will always invoke the "ask me again" API, and it will do so in a separate + // event loop tick to avoid awkward re-entrancy with this method. + promise.futureResult.whenComplete { result in + // When we complete here we need to set our result state, and then ask to respin certificate verification. + // If we can't respin verification because we've dropped the parent handler, that's fine, no harm no foul. + // For that reason, we tolerate both the verify manager and the parent handler being nil. + eventLoop.execute { + // Note that we don't close over self here: that's to deal with the fact that this is a struct, and we don't want to + // escape the mutable ownership of self. + precondition(connection.customVerificationManager == nil || connection.customVerificationManager?.result == .some(.pendingResult)) + connection.customVerificationManager?.result = .complete(NIOSSLVerificationResult(result)) + connection.parentHandler?.asynchronousCertificateVerificationComplete() + } + } + + // Ok, let's do it. + self.callback.invoke(on: connection, promise: promise) + return ssl_verify_retry + } +} + + +extension CustomVerifyManager { + private enum CallbackType { + case `public`(NIOSSLCustomVerificationCallback) + case `internal`(InternalCallback) + + /// For user-supplied callbacks we need to give them public types. For internal ones, we just pass the + /// `EventLoopPromise` object through. + func invoke(on connection: SSLConnection, promise: EventLoopPromise) { + switch self { + case .public(let publicCallback): + do { + let certificates = try connection.peerCertificateChain() + publicCallback(certificates, promise) + } catch { + promise.fail(error) + } + + case .internal(let internalCallback): + internalCallback(promise) + } + } + } + + internal typealias InternalCallback = (EventLoopPromise) -> Void +} + + +extension NIOSSLVerificationResult { + init(_ result: Result) { + switch result { + case .success(let s): + self = s + case .failure: + self = .failed + } + } +} diff --git a/Sources/NIOSSL/SSLCertificate.swift b/Sources/NIOSSL/SSLCertificate.swift index dc033c54..813959a3 100644 --- a/Sources/NIOSSL/SSLCertificate.swift +++ b/Sources/NIOSSL/SSLCertificate.swift @@ -47,7 +47,7 @@ public class NIOSSLCertificate { case ipv6(in6_addr) } - private init(withReference ref: UnsafeMutablePointer) { + private init(withOwnedReference ref: UnsafeMutablePointer) { self._ref = UnsafeMutableRawPointer(ref) // erasing the type for @_implementationOnly import CNIOBoringSSL } @@ -73,7 +73,7 @@ public class NIOSSLCertificate { throw NIOSSLError.failedToLoadCertificate } - self.init(withReference: x509!) + self.init(withOwnedReference: x509!) } /// Create a NIOSSLCertificate from a buffer of bytes in either PEM or @@ -107,7 +107,36 @@ public class NIOSSLCertificate { throw NIOSSLError.failedToLoadCertificate } - self.init(withReference: ref!) + self.init(withOwnedReference: ref!) + } + + /// Create a NIOSSLCertificate from a buffer of bytes in either PEM or DER format. + internal convenience init(bytes ptr: UnsafeRawBufferPointer, format: NIOSSLSerializationFormats) throws { + // TODO(cory): + // The body of this method is exactly identical to the initializer above, except for the "withUnsafeBytes" call. + // ContiguousBytes would have been the lowest effort way to reduce this duplication, but we can't use it without + // bringing Foundation in. Probably we should use Sequence where Element == UInt8 and the withUnsafeContiguousBytesIfAvailable + // method, but that's a much more substantial refactor. Let's do it later. + let bio = CNIOBoringSSL_BIO_new_mem_buf(ptr.baseAddress, CInt(ptr.count))! + + defer { + CNIOBoringSSL_BIO_free(bio) + } + + let ref: UnsafeMutablePointer? + + switch format { + case .pem: + ref = CNIOBoringSSL_PEM_read_bio_X509(bio, nil, nil, nil) + case .der: + ref = CNIOBoringSSL_d2i_X509_bio(bio, nil) + } + + if ref == nil { + throw NIOSSLError.failedToLoadCertificate + } + + self.init(withOwnedReference: ref!) } /// Create a NIOSSLCertificate wrapping a pointer into BoringSSL. @@ -121,7 +150,7 @@ public class NIOSSLCertificate { /// In general, however, this function should be avoided in favour of one of the convenience /// initializers, which ensure that the lifetime of the `X509` object is better-managed. static func fromUnsafePointer(takingOwnership pointer: UnsafeMutablePointer) -> NIOSSLCertificate { - return NIOSSLCertificate(withReference: pointer) + return NIOSSLCertificate(withOwnedReference: pointer) } /// Get a sequence of the alternative names in the certificate. @@ -267,10 +296,10 @@ extension NIOSSLCertificate { throw NIOSSLError.failedToLoadCertificate } - var certificates = [NIOSSLCertificate(withReference: x509)] + var certificates = [NIOSSLCertificate(withOwnedReference: x509)] while let x = CNIOBoringSSL_PEM_read_bio_X509(bio, nil, nil, nil) { - certificates.append(.init(withReference: x)) + certificates.append(.init(withOwnedReference: x)) } let err = CNIOBoringSSL_ERR_peek_error() diff --git a/Sources/NIOSSL/SSLConnection.swift b/Sources/NIOSSL/SSLConnection.swift index 1429573b..ba42a13b 100644 --- a/Sources/NIOSSL/SSLConnection.swift +++ b/Sources/NIOSSL/SSLConnection.swift @@ -48,13 +48,15 @@ internal final class SSLConnection { private let ssl: OpaquePointer private let parentContext: NIOSSLContext private var bio: ByteBufferBIO? - private var verificationCallback: NIOSSLVerificationCallback? - internal var platformVerificationState: PlatformVerificationState = PlatformVerificationState() internal var expectedHostname: String? internal var role: ConnectionRole? internal var parentHandler: NIOSSLHandler? internal var eventLoop: EventLoop? + /// Deprecated in favour of customVerificationManager + private var verificationCallback: NIOSSLVerificationCallback? + internal var customVerificationManager: CustomVerifyManager? + /// Whether certificate hostnames should be validated. var validateHostnames: Bool { if case .fullVerification = parentContext.configuration.certificateVerification { @@ -116,7 +118,9 @@ internal final class SSLConnection { self.expectedHostname = name } - /// Sets the BoringSSL verification callback. + /// Sets the BoringSSL old-style verification callback. + /// + /// This is deprecated in favour of the new-style verification callback in SSLContext. func setVerificationCallback(_ callback: @escaping NIOSSLVerificationCallback) { // Store the verification callback. We need to do this to keep it alive throughout the connection. // We'll drop this when we're told that it's no longer needed to ensure we break the reference cycles @@ -156,6 +160,33 @@ internal final class SSLConnection { } } + func setCustomVerificationCallback(_ callbackManager: CustomVerifyManager) { + // Store the verification callback. We need to do this to keep it alive throughout the connection. + // We'll drop this when we're told that it's no longer needed to ensure we break the reference cycles + // that this callback inevitably produces. + assert(self.customVerificationManager == nil) + self.customVerificationManager = callbackManager + + // We need to know what the current mode is. + let currentMode = CNIOBoringSSL_SSL_get_verify_mode(self.ssl) + CNIOBoringSSL_SSL_set_custom_verify(self.ssl, currentMode) { ssl, outAlert in + guard let unwrappedSSL = ssl else { + preconditionFailure("Unexpected null pointer in custom verification callback. ssl: \(String(describing: ssl))") + } + + // Ok, this call may be a resumption of a previous negotiation. We need to check if our connection object has a pre-existing verifiation state. + guard let connectionPointer = CNIOBoringSSL_SSL_get_ex_data(unwrappedSSL, sslConnectionExDataIndex) else { + // Uh-ok, our application state is gone. Don't let this error silently pass, go bang. + preconditionFailure("Unable to find application data from SSL * \(unwrappedSSL), index \(sslConnectionExDataIndex)") + } + + let connection = Unmanaged.fromOpaque(connectionPointer).takeUnretainedValue() + + // We force unwrap the custom verification manager because for it to not be set is a programmer error. + return connection.customVerificationManager!.process(on: connection) + } + } + /// Sets whether renegotiation is supported. func setRenegotiationSupport(_ state: NIORenegotiationSupport) { var baseState: ssl_renegotiate_mode_t @@ -364,9 +395,10 @@ internal final class SSLConnection { /// Must only be called when the connection is no longer needed. The rest of this object /// preconditions on that being true, so we'll find out quickly when that's not the case. func close() { - /// Drop the verification callback. This breaks any reference cycles that are inevitably - /// created by this callback. + /// Drop the verification callbacks. This breaks any reference cycles that are inevitably + /// created by these callbacks. self.verificationCallback = nil + self.customVerificationManager = nil // Also drop the reference to the parent channel handler, which is a trivial reference cycle. self.parentHandler = nil @@ -422,6 +454,17 @@ extension SSLConnection { return try body(PeerCertificateChainBuffers(basePointer: stackPointer)) } + + /// The certificate chain presented by the peer. + func peerCertificateChain() throws -> [NIOSSLCertificate] { + return try self.withPeerCertificateChainBuffers { buffers in + guard let buffers = buffers else { + return [] + } + + return try buffers.map { try NIOSSLCertificate(bytes: $0, format: .der) } + } + } } extension SSLConnection.PeerCertificateChainBuffers: RandomAccessCollection { diff --git a/Sources/NIOSSL/SSLContext.swift b/Sources/NIOSSL/SSLContext.swift index bcdcc773..8c32d2b0 100644 --- a/Sources/NIOSSL/SSLContext.swift +++ b/Sources/NIOSSL/SSLContext.swift @@ -246,7 +246,20 @@ public final class NIOSSLContext { guard let ssl = CNIOBoringSSL_SSL_new(self.sslContext) else { return nil } - return SSLConnection(ownedSSL: ssl, parentContext: self) + + let conn = SSLConnection(ownedSSL: ssl, parentContext: self) + + // If we need to turn on the validation on Apple platforms, do it here. + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + switch self.configuration.trustRoots { + case .some(.default), .none: + conn.setCustomVerificationCallback(CustomVerifyManager(callback: conn.performSecurityFrameworkValidation(promise:))) + case .some(.certificates), .some(.file): + break + } + #endif + + return conn } fileprivate func alpnSelectCallback(offeredProtocols: UnsafeBufferPointer) -> (index: Int, length: Int)? { @@ -397,8 +410,9 @@ extension NIOSSLContext { } private static func platformDefaultConfiguration(context: OpaquePointer) throws { - // Platform default trust is configured differently in different places. On Darwin we invoke Security.framework in a custom callback. + // Platform default trust is configured differently in different places. // On Linux, we use our searched heuristics to guess about where the platform trust store is. + // On Darwin, we use a custom callback that is set later, in createConnection #if os(Linux) let result = rootCAFilePath.withCString { rootCAFilePointer in rootCADirectoryPath.withCString { rootCADirectoryPointer in @@ -410,8 +424,6 @@ extension NIOSSLContext { let errorStack = BoringSSLError.buildErrorStack() throw BoringSSLError.unknownError(errorStack) } - #elseif os(macOS) || os(iOS) || os(watchOS) || os(tvOS) - CNIOBoringSSL_SSL_CTX_set_custom_verify(context, SSL_VERIFY_PEER, securityFrameworkCustomVerify) #endif } diff --git a/Sources/NIOSSL/SSLErrors.swift b/Sources/NIOSSL/SSLErrors.swift index 3e4474ff..c1664a05 100644 --- a/Sources/NIOSSL/SSLErrors.swift +++ b/Sources/NIOSSL/SSLErrors.swift @@ -210,6 +210,7 @@ extension NIOSSLExtraError { private enum BaseError: Equatable { case failedToValidateHostname case serverHostnameImpossibleToMatch + case cannotUseIPAddressInSNI } } @@ -221,15 +222,26 @@ extension NIOSSLExtraError { /// The server hostname provided by the user cannot match any names in the certificate due to containing invalid characters. public static let serverHostnameImpossibleToMatch = NIOSSLExtraError(baseError: .serverHostnameImpossibleToMatch, description: nil) + /// IP addresses may not be used in SNI. + public static let cannotUseIPAddressInSNI = NIOSSLExtraError(baseError: .cannotUseIPAddressInSNI, description: nil) + + @inline(never) internal static func failedToValidateHostname(expectedName: String) -> NIOSSLExtraError { let description = "Couldn't find \(expectedName) in certificate from peer" return NIOSSLExtraError(baseError: .failedToValidateHostname, description: description) } + @inline(never) internal static func serverHostnameImpossibleToMatch(hostname: String) -> NIOSSLExtraError { let description = "The server hostname \(hostname) cannot be matched due to containing non-DNS characters" return NIOSSLExtraError(baseError: .serverHostnameImpossibleToMatch, description: description) } + + @inline(never) + internal static func cannotUseIPAddressInSNI(ipAddress: String) -> NIOSSLExtraError { + let description = "IP addresses cannot validly be used for Server Name Indication, got \(ipAddress)" + return NIOSSLExtraError(baseError: .cannotUseIPAddressInSNI, description: description) + } } diff --git a/Sources/NIOSSL/SecurityFrameworkCertificateVerification.swift b/Sources/NIOSSL/SecurityFrameworkCertificateVerification.swift index 0069207e..658fec87 100644 --- a/Sources/NIOSSL/SecurityFrameworkCertificateVerification.swift +++ b/Sources/NIOSSL/SecurityFrameworkCertificateVerification.swift @@ -11,6 +11,7 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// +import NIO #if compiler(>=5.1) && compiler(<5.2) @_implementationOnly import CNIOBoringSSL @@ -18,126 +19,59 @@ import CNIOBoringSSL #endif -/// The current state of the platform verification helper, if one is in use. -/// -/// Only used on Apple platforms currently. -internal struct PlatformVerificationState { - #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) - fileprivate var state: SecurityFrameworkVerificationState? = nil - #endif -} - // We can only use Security.framework to validate TLS certificates on Apple platforms. #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) import Dispatch import Foundation import Security -/// A custom certificate verification function for BoringSSL that uses Security.framework to provide certificate verification. -/// -/// - parameters: -/// - ssl: The pointer to the SSL * object for this connection. -/// - outAlert: A C-style inout parameter that contains a pointer to an alert. Is assumed to be non-null. -internal func securityFrameworkCustomVerify(_ ssl: OpaquePointer?, _ outAlert: UnsafeMutablePointer?) -> ssl_verify_result_t { - guard let unwrappedSSL = ssl, let unwrappedOutAlert = outAlert else { - preconditionFailure("Unexpected null pointer in custom verification callback. ssl: \(String(describing: ssl)) outAlert: \(String(describing: outAlert))") - } - - // Ok, this call may be a resumption of a previous negotiation. We need to check if our connection object has a pre-existing verifiation state. - guard let connectionPointer = CNIOBoringSSL_SSL_get_ex_data(unwrappedSSL, sslConnectionExDataIndex) else { - // Uh-ok, our application state is gone. Don't let this error silently pass, go bang. - preconditionFailure("Unable to find application data from SSL * \(unwrappedSSL), index \(sslConnectionExDataIndex)") - } - - let connection = Unmanaged.fromOpaque(connectionPointer).takeUnretainedValue() - - do { - return try connection.performSecurityFrameworkValidation() - } catch { - unwrappedOutAlert.pointee = UInt8(SSL_AD_INTERNAL_ERROR) - return ssl_verify_invalid - } -} - -extension PlatformVerificationState { - fileprivate enum SecurityFrameworkVerificationState { - case pendingResult - - case complete(SecTrustResultType) - } -} - extension SSLConnection { - func performSecurityFrameworkValidation() throws -> ssl_verify_result_t { - // First, check whether we have an outstanding or completed query. If we do, don't do any other work. - switch self.platformVerificationState.state { - case .some(.complete(.proceed)), .some(.complete(.unspecified)): - // These two cases mean we have successfully validated the certificate. We're done! Wipe out the state so - // that if we need to reverify we can, and return the success. - self.platformVerificationState.state = nil - return ssl_verify_ok - case .some(.complete): - // Ok, this broader case means we failed. We're still done, but return the failure instead. - self.platformVerificationState.state = nil - return ssl_verify_invalid - case .some(.pendingResult): - // We've got a validation attempt outstanding. Tell BoringSSL to hold its horses. - return ssl_verify_retry - case .none: - // No verification outstanding: do more work. - break - } + func performSecurityFrameworkValidation(promise: EventLoopPromise) { + do { + // Ok, time to kick off a validation. Let's get some certificate buffers. + let certificates: [SecCertificate] = try self.withPeerCertificateChainBuffers { buffers in + guard let buffers = buffers else { + throw NIOSSLError.unableToValidateCertificate + } + + return try buffers.map { buffer in + let data = Data(bytes: buffer.baseAddress!, count: buffer.count) + guard let cert = SecCertificateCreateWithData(nil, data as CFData) else { + throw NIOSSLError.unableToValidateCertificate + } + return cert + } + } - // Ok, time to kick off a validation. Let's get some certificate buffers. - let certificates: [SecCertificate] = try self.withPeerCertificateChainBuffers { buffers in - guard let buffers = buffers else { + // This force-unwrap is safe as we must have decided if we're a client or a server before validation. + var trust: SecTrust? = nil + var result: OSStatus + let policy = SecPolicyCreateSSL(self.role! == .client, self.expectedHostname as CFString?) + result = SecTrustCreateWithCertificates(certificates as CFArray, policy, &trust) + guard result == errSecSuccess, let actualTrust = trust else { throw NIOSSLError.unableToValidateCertificate } - return try buffers.map { buffer in - let data = Data(bytes: buffer.baseAddress!, count: buffer.count) - guard let cert = SecCertificateCreateWithData(nil, data as CFData) else { - throw NIOSSLError.unableToValidateCertificate + // We create a DispatchQueue here to be called back on, as this validation may perform network activity. + let callbackQueue = DispatchQueue(label: "io.swiftnio.ssl.validationCallbackQueue") + + result = SecTrustEvaluateAsync(actualTrust, callbackQueue) { (_, result) in + switch result { + case .proceed, .unspecified: + // These two cases mean we have successfully validated the certificate. We're done! + promise.succeed(.certificateVerified) + default: + // Oops, we failed. + promise.succeed(.failed) } - return cert } - } - // This force-unwrap is safe as we must have decided if we're a client or a server before validation. - var trust: SecTrust? = nil - var result: OSStatus - let policy = SecPolicyCreateSSL(self.role! == .client, self.expectedHostname as CFString?) - result = SecTrustCreateWithCertificates(certificates as CFArray, policy, &trust) - guard result == errSecSuccess, let actualTrust = trust else { - throw NIOSSLError.unableToValidateCertificate - } - - // We create a DispatchQueue here to be called back on, as this validation may perform network activity. - let callbackQueue = DispatchQueue(label: "io.swiftnio.ssl.validationCallbackQueue") - - // Now we need to grab some things we need in the callback block. Specifically, we need the parent handler - // and the event loop it belongs to. This is because we cannot safely access these things from inside the - // block, and we need the eventLoop to get back onto a safe thread. - // - // We don't hold these references weak because we are ok with keeping the handler alive longer than necessary - // in the rare case that the handler is removed before the callback completes. - // The force-unwrap is safe, as we cannot be midway through handshaking before the connection has become active. - let eventLoop = self.eventLoop! - - result = SecTrustEvaluateAsync(actualTrust, callbackQueue) { (_, result) in - // When we complete here we need to set our result state, and then ask to respin certificate verification. - // If we can't respin verification because we've dropped the parent handler, that's fine, no harm no foul. - eventLoop.execute { - self.platformVerificationState.state = .complete(result) - self.parentHandler?.asynchronousCertificateVerificationComplete() + guard result == errSecSuccess else { + throw NIOSSLError.unableToValidateCertificate } + } catch { + promise.fail(error) } - - guard result == errSecSuccess else { - throw NIOSSLError.unableToValidateCertificate - } - - return ssl_verify_retry } } diff --git a/Tests/NIOSSLTests/ClientSNITests.swift b/Tests/NIOSSLTests/ClientSNITests.swift index ca743a4b..a1e995b7 100644 --- a/Tests/NIOSSLTests/ClientSNITests.swift +++ b/Tests/NIOSSLTests/ClientSNITests.swift @@ -83,7 +83,7 @@ class ClientSNITests: XCTestCase { do { _ = try NIOSSLClientHandler(context: context, serverHostname: "192.168.0.1") XCTFail("Created client handler with invalid SNI name") - } catch BoringSSLError.invalidSNIName { + } catch let err as NIOSSLExtraError where err == NIOSSLExtraError.cannotUseIPAddressInSNI { // All fine. } } @@ -94,7 +94,7 @@ class ClientSNITests: XCTestCase { do { _ = try NIOSSLClientHandler(context: context, serverHostname: "fe80::200:f8ff:fe21:67cf") XCTFail("Created client handler with invalid SNI name") - } catch BoringSSLError.invalidSNIName { + } catch let err as NIOSSLExtraError where err == NIOSSLExtraError.cannotUseIPAddressInSNI { // All fine. } } diff --git a/Tests/NIOSSLTests/NIOSSLIntegrationTest+XCTest.swift b/Tests/NIOSSLTests/NIOSSLIntegrationTest+XCTest.swift index 046dfee8..87765dda 100644 --- a/Tests/NIOSSLTests/NIOSSLIntegrationTest+XCTest.swift +++ b/Tests/NIOSSLTests/NIOSSLIntegrationTest+XCTest.swift @@ -47,6 +47,10 @@ extension NIOSSLIntegrationTest { ("testFlushPendingReadsOnCloseNotify", testFlushPendingReadsOnCloseNotify), ("testForcingVerificationFailure", testForcingVerificationFailure), ("testExtractingCertificates", testExtractingCertificates), + ("testForcingVerificationFailureNewCallback", testForcingVerificationFailureNewCallback), + ("testErroringNewVerificationCallback", testErroringNewVerificationCallback), + ("testNewCallbackCanDelayHandshake", testNewCallbackCanDelayHandshake), + ("testExtractingCertificatesNewCallback", testExtractingCertificatesNewCallback), ("testRepeatedClosure", testRepeatedClosure), ("testClosureTimeout", testClosureTimeout), ("testReceivingGibberishAfterAttemptingToClose", testReceivingGibberishAfterAttemptingToClose), diff --git a/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift b/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift index 1ff76f11..1c74127c 100644 --- a/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift +++ b/Tests/NIOSSLTests/NIOSSLIntegrationTest.swift @@ -18,6 +18,7 @@ import XCTest #else import CNIOBoringSSL #endif +import NIOConcurrencyHelpers import NIO @testable import NIOSSL import NIOTLS @@ -292,21 +293,70 @@ internal func clientTLSChannel(context: NIOSSLContext, group: EventLoopGroup, connectingTo: SocketAddress, serverHostname: String? = nil, - verificationCallback: NIOSSLVerificationCallback? = nil, file: StaticString = #file, line: UInt = #line) throws -> Channel { + func handlerFactory() throws -> NIOSSLClientHandler { + return try NIOSSLClientHandler(context: context, serverHostname: serverHostname) + } + + return try _clientTLSChannel(context: context, preHandlers: preHandlers, postHandlers: postHandlers, group: group, connectingTo: connectingTo, handlerFactory: handlerFactory) +} + + +@available(*, deprecated, renamed: "clientTLSChannel(context:preHandlers:postHandlers:group:connectingTo:serverHostname:customVerificationCallback:file:line:)") +internal func clientTLSChannel(context: NIOSSLContext, + preHandlers: [ChannelHandler], + postHandlers: [ChannelHandler], + group: EventLoopGroup, + connectingTo: SocketAddress, + serverHostname: String? = nil, + verificationCallback: @escaping NIOSSLVerificationCallback, + file: StaticString = #file, + line: UInt = #line) throws -> Channel { + func handlerFactory() throws -> NIOSSLClientHandler { + return try NIOSSLClientHandler(context: context, serverHostname: serverHostname, verificationCallback: verificationCallback) + } + + return try _clientTLSChannel(context: context, preHandlers: preHandlers, postHandlers: postHandlers, group: group, connectingTo: connectingTo, handlerFactory: handlerFactory) +} + + +internal func clientTLSChannel(context: NIOSSLContext, + preHandlers: [ChannelHandler], + postHandlers: [ChannelHandler], + group: EventLoopGroup, + connectingTo: SocketAddress, + serverHostname: String? = nil, + customVerificationCallback: @escaping NIOSSLCustomVerificationCallback, + file: StaticString = #file, + line: UInt = #line) throws -> Channel { + func handlerFactory() throws -> NIOSSLClientHandler { + return try NIOSSLClientHandler(context: context, serverHostname: serverHostname, customVerificationCallback: customVerificationCallback) + } + + return try _clientTLSChannel(context: context, preHandlers: preHandlers, postHandlers: postHandlers, group: group, connectingTo: connectingTo, handlerFactory: handlerFactory) +} + +fileprivate func _clientTLSChannel(context: NIOSSLContext, + preHandlers: [ChannelHandler], + postHandlers: [ChannelHandler], + group: EventLoopGroup, + connectingTo: SocketAddress, + handlerFactory: @escaping () throws -> NIOSSLClientHandler, + file: StaticString = #file, + line: UInt = #line) throws -> Channel { return try assertNoThrowWithValue(ClientBootstrap(group: group) .channelInitializer { channel in let results = preHandlers.map { channel.pipeline.addHandler($0) } return EventLoopFuture.andAllSucceed(results, on: results.first?.eventLoop ?? group.next()).flatMapThrowing { - try NIOSSLClientHandler(context: context, serverHostname: serverHostname, verificationCallback: verificationCallback) - }.flatMap { - channel.pipeline.addHandler($0) - }.flatMap { - let results = postHandlers.map { channel.pipeline.addHandler($0) } - return EventLoopFuture.andAllSucceed(results, on: results.first?.eventLoop ?? group.next()) - } - }.connect(to: connectingTo).wait(), file: file, line: line) + try handlerFactory() + }.flatMap { + channel.pipeline.addHandler($0) + }.flatMap { + let results = postHandlers.map { channel.pipeline.addHandler($0) } + return EventLoopFuture.andAllSucceed(results, on: results.first?.eventLoop ?? group.next()) + } + }.connect(to: connectingTo).wait(), file: file, line: line) } class NIOSSLIntegrationTest: XCTestCase { @@ -1161,6 +1211,7 @@ class NIOSSLIntegrationTest: XCTestCase { XCTAssertEqual(readData.readString(length: readData.readableBytes)!, "Hello") } + @available(*, deprecated, message: "Testing deprecated API surface") func testForcingVerificationFailure() throws { let context = try configuredSSLContext() @@ -1209,6 +1260,7 @@ class NIOSSLIntegrationTest: XCTestCase { } } + @available(*, deprecated, message: "Testing deprecated API surface") func testExtractingCertificates() throws { let context = try configuredSSLContext() @@ -1249,6 +1301,206 @@ class NIOSSLIntegrationTest: XCTestCase { XCTAssertEqual(certificates.count, 1) } + func testForcingVerificationFailureNewCallback() throws { + let context = try configuredSSLContext() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverChannel: Channel = try serverTLSChannel(context: context, handlers: [], group: group) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + let errorHandler = ErrorCatcher() + let clientChannel = try clientTLSChannel(context: try configuredClientContext(), + preHandlers: [], + postHandlers: [errorHandler], + group: group, + connectingTo: serverChannel.localAddress!, + customVerificationCallback: { _, promise in + promise.succeed(.failed) + }) + + var originalBuffer = clientChannel.allocator.buffer(capacity: 5) + originalBuffer.writeString("Hello") + let writeFuture = clientChannel.writeAndFlush(originalBuffer) + let errorsFuture: EventLoopFuture<[NIOSSLError]> = writeFuture.recover { (_: Error) in + // We're swallowing errors here, on purpose, because we'll definitely + // hit them. + return () + }.map { + return errorHandler.errors + } + let actualErrors = try errorsFuture.wait() + + // This write will have failed, but that's fine: we just want it as a signal that + // the handshake is done so we can make our assertions. + XCTAssertEqual(actualErrors.count, 1) + switch actualErrors.first! { + case .handshakeFailed: + // expected + break + case let error: + XCTFail("Unexpected error: \(error)") + } + } + + func testErroringNewVerificationCallback() throws { + enum LocalError: Error { + case kaboom + } + + let context = try configuredSSLContext() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let serverChannel: Channel = try serverTLSChannel(context: context, handlers: [], group: group) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + let errorHandler = ErrorCatcher() + let clientChannel = try clientTLSChannel(context: try configuredClientContext(), + preHandlers: [], + postHandlers: [errorHandler], + group: group, + connectingTo: serverChannel.localAddress!, + customVerificationCallback: { _, promise in + promise.fail(LocalError.kaboom) + }) + + var originalBuffer = clientChannel.allocator.buffer(capacity: 5) + originalBuffer.writeString("Hello") + let writeFuture = clientChannel.writeAndFlush(originalBuffer) + let errorsFuture: EventLoopFuture<[NIOSSLError]> = writeFuture.recover { (_: Error) in + // We're swallowing errors here, on purpose, because we'll definitely + // hit them. + return () + }.map { + return errorHandler.errors + } + let actualErrors = try errorsFuture.wait() + + // This write will have failed, but that's fine: we just want it as a signal that + // the handshake is done so we can make our assertions. + XCTAssertEqual(actualErrors.count, 1) + switch actualErrors.first! { + case .handshakeFailed: + // expected + break + case let error: + XCTFail("Unexpected error: \(error)") + } + } + + func testNewCallbackCanDelayHandshake() throws { + let context = try configuredSSLContext() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + + var completionPromiseFired: Bool = false + let completionPromiseFiredLock = Lock() + + let completionPromise: EventLoopPromise = group.next().makePromise() + completionPromise.futureResult.whenComplete { _ in + completionPromiseFiredLock.withLock { + completionPromiseFired = true + } + } + + + let serverChannel: Channel = try serverTLSChannel(context: context, handlers: [SimpleEchoServer()], group: group) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + var handshakeCompletePromise: EventLoopPromise? = nil + let handshakeFiredPromise: EventLoopPromise = group.next().makePromise() + + let clientChannel = try clientTLSChannel(context: configuredClientContext(), + preHandlers: [], + postHandlers: [PromiseOnReadHandler(promise: completionPromise)], + group: group, + connectingTo: serverChannel.localAddress!, + serverHostname: "localhost", + customVerificationCallback: { innerCertificates, promise in + handshakeCompletePromise = promise + handshakeFiredPromise.succeed(()) + }) + defer { + XCTAssertNoThrow(try clientChannel.close().wait()) + } + + var originalBuffer = clientChannel.allocator.buffer(capacity: 5) + originalBuffer.writeString("Hello") + clientChannel.writeAndFlush(originalBuffer, promise: nil) + + // This has driven the handshake to begin, so we can wait for that. + XCTAssertNoThrow(try handshakeFiredPromise.futureResult.wait()) + + // We can now check whether the completion promise has fired: it should not have. + completionPromiseFiredLock.withLock { + XCTAssertFalse(completionPromiseFired) + } + + // Ok, allow the handshake to run. + handshakeCompletePromise!.succeed(.certificateVerified) + + let newBuffer = try completionPromise.futureResult.wait() + XCTAssertTrue(completionPromiseFired) + XCTAssertEqual(newBuffer, originalBuffer) + } + + func testExtractingCertificatesNewCallback() throws { + let context = try configuredSSLContext() + + let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let completionPromise: EventLoopPromise = group.next().makePromise() + + let serverChannel: Channel = try serverTLSChannel(context: context, handlers: [SimpleEchoServer()], group: group) + defer { + XCTAssertNoThrow(try serverChannel.close().wait()) + } + + var certificates = [NIOSSLCertificate]() + let clientChannel = try clientTLSChannel(context: configuredClientContext(), + preHandlers: [], + postHandlers: [PromiseOnReadHandler(promise: completionPromise)], + group: group, + connectingTo: serverChannel.localAddress!, + serverHostname: "localhost", + customVerificationCallback: { innerCertificates, promise in + certificates = innerCertificates + promise.succeed(.certificateVerified) + }) + defer { + XCTAssertNoThrow(try clientChannel.close().wait()) + } + + var originalBuffer = clientChannel.allocator.buffer(capacity: 5) + originalBuffer.writeString("Hello") + XCTAssertNoThrow(try clientChannel.writeAndFlush(originalBuffer).wait()) + + let newBuffer = try completionPromise.futureResult.wait() + XCTAssertEqual(newBuffer, originalBuffer) + + XCTAssertEqual(certificates, [NIOSSLIntegrationTest.cert]) + } + func testRepeatedClosure() throws { let serverChannel = EmbeddedChannel() let clientChannel = EmbeddedChannel()