diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index 4bd689804..72a911c97 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -113,23 +113,38 @@ final class ConnectionPool { self.scheme = .https case "unix": self.scheme = .unix - self.unixPath = request.url.baseURL?.path ?? request.url.path + case "http+unix": + self.scheme = .http_unix + case "https+unix": + self.scheme = .https_unix default: fatalError("HTTPClient.Request scheme should already be a valid one") } self.port = request.port self.host = request.host + self.unixPath = request.socketPath } var scheme: Scheme var host: String var port: Int - var unixPath: String = "" + var unixPath: String enum Scheme: Hashable { case http case https case unix + case http_unix + case https_unix + + var requiresTLS: Bool { + switch self { + case .https, .https_unix: + return true + default: + return false + } + } } } } @@ -433,7 +448,7 @@ class HTTP1ConnectionProvider { private func makeChannel(preference: HTTPClient.EventLoopPreference) -> EventLoopFuture { let eventLoop = preference.bestEventLoop ?? self.eventLoop - let requiresTLS = self.key.scheme == .https + let requiresTLS = self.key.scheme.requiresTLS let bootstrap: NIOClientTCPBootstrap do { bootstrap = try NIOClientTCPBootstrap.makeHTTPClientBootstrapBase(on: eventLoop, host: self.key.host, port: self.key.port, requiresTLS: requiresTLS, configuration: self.configuration) @@ -446,12 +461,12 @@ class HTTP1ConnectionProvider { case .http, .https: let address = HTTPClient.resolveAddress(host: self.key.host, port: self.key.port, proxy: self.configuration.proxy) channel = bootstrap.connect(host: address.host, port: address.port) - case .unix: + case .unix, .http_unix, .https_unix: channel = bootstrap.connect(unixDomainSocketPath: self.key.unixPath) } return channel.flatMap { channel in - let requiresSSLHandler = self.configuration.proxy != nil && self.key.scheme == .https + let requiresSSLHandler = self.configuration.proxy != nil && self.key.scheme.requiresTLS let handshakePromise = channel.eventLoop.makePromise(of: Void.self) channel.pipeline.addSSLHandlerIfNeeded(for: self.key, tlsConfiguration: self.configuration.tlsConfiguration, addSSLClient: requiresSSLHandler, handshakePromise: handshakePromise) diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 49619a814..a119ce90d 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -654,7 +654,7 @@ extension ChannelPipeline { } func addSSLHandlerIfNeeded(for key: ConnectionPool.Key, tlsConfiguration: TLSConfiguration?, addSSLClient: Bool, handshakePromise: EventLoopPromise) { - guard key.scheme == .https else { + guard key.scheme.requiresTLS else { handshakePromise.succeed(()) return } @@ -665,7 +665,7 @@ extension ChannelPipeline { let tlsConfiguration = tlsConfiguration ?? TLSConfiguration.forClient() let context = try NIOSSLContext(configuration: tlsConfiguration) handlers = [ - try NIOSSLClientHandler(context: context, serverHostname: key.host.isIPAddress ? nil : key.host), + try NIOSSLClientHandler(context: context, serverHostname: (key.host.isIPAddress || key.host.isEmpty) ? nil : key.host), TLSEventsHandler(completionPromise: handshakePromise), ] } else { @@ -726,6 +726,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { private enum Code: Equatable { case invalidURL case emptyHost + case missingSocketPath case alreadyShutdown case emptyScheme case unsupportedScheme(String) @@ -758,6 +759,8 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let invalidURL = HTTPClientError(code: .invalidURL) /// URL does not contain host. public static let emptyHost = HTTPClientError(code: .emptyHost) + /// URL does not contain a socketPath as a host for http(s)+unix shemes. + public static let missingSocketPath = HTTPClientError(code: .missingSocketPath) /// Client is shutdown and cannot be used for new requests. public static let alreadyShutdown = HTTPClientError(code: .alreadyShutdown) /// URL does not contain scheme. diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index b02396b75..77220f068 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -99,20 +99,27 @@ extension HTTPClient { public struct Request { /// Represent kind of Request enum Kind { + enum UnixScheme { + case baseURL + case http_unix + case https_unix + } + /// Remote host request. case host /// UNIX Domain Socket HTTP request. - case unixSocket + case unixSocket(_ scheme: UnixScheme) private static var hostSchemes = ["http", "https"] - private static var unixSchemes = ["unix"] + private static var unixSchemes = ["unix", "http+unix", "https+unix"] init(forScheme scheme: String) throws { - if Kind.host.supports(scheme: scheme) { - self = .host - } else if Kind.unixSocket.supports(scheme: scheme) { - self = .unixSocket - } else { + switch scheme { + case "http", "https": self = .host + case "unix": self = .unixSocket(.baseURL) + case "http+unix": self = .unixSocket(.http_unix) + case "https+unix": self = .unixSocket(.https_unix) + default: throw HTTPClientError.unsupportedScheme(scheme) } } @@ -129,6 +136,31 @@ extension HTTPClient { } } + func socketPathFromURL(_ url: URL) throws -> String { + switch self { + case .unixSocket(.baseURL): + return url.baseURL?.path ?? url.path + case .unixSocket: + guard let socketPath = url.host else { + throw HTTPClientError.missingSocketPath + } + return socketPath + case .host: + return "" + } + } + + func uriFromURL(_ url: URL) -> String { + switch self { + case .host: + return url.uri + case .unixSocket(.baseURL): + return url.baseURL != nil ? url.uri : "/" + case .unixSocket: + return url.uri + } + } + func supports(scheme: String) -> Bool { switch self { case .host: @@ -147,6 +179,10 @@ extension HTTPClient { public let scheme: String /// Remote host, resolved from `URL`. public let host: String + /// Socket path, resolved from `URL`. + let socketPath: String + /// URI composed of the path and query, resolved from `URL`. + let uri: String /// Request custom HTTP Headers, defaults to no headers. public var headers: HTTPHeaders /// Request body, defaults to no body. @@ -192,6 +228,7 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. + /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { guard let scheme = url.scheme?.lowercased() else { throw HTTPClientError.emptyScheme @@ -199,6 +236,8 @@ extension HTTPClient { self.kind = try Kind(forScheme: scheme) self.host = try self.kind.hostFromURL(url) + self.socketPath = try self.kind.socketPathFromURL(url) + self.uri = self.kind.uriFromURL(url) self.redirectState = nil self.url = url @@ -210,7 +249,7 @@ extension HTTPClient { /// Whether request will be executed using secure socket. public var useTLS: Bool { - return self.scheme == "https" + return self.scheme == "https" || self.scheme == "https+unix" } /// Resolved port. @@ -712,19 +751,9 @@ extension TaskHandler: ChannelDuplexHandler { self.state = .idle let request = self.unwrapOutboundIn(data) - let uri: String - switch (self.kind, request.url.baseURL) { - case (.host, _): - uri = request.url.uri - case (.unixSocket, .none): - uri = "/" // we don't have a real path, the path we have is the path of the UNIX Domain Socket. - case (.unixSocket, .some(_)): - uri = request.url.uri - } - var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1), method: request.method, - uri: uri) + uri: request.uri) var headers = request.headers if !request.headers.contains(name: "Host") { diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index 9160653e1..7da957b07 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -64,7 +64,7 @@ extension ClientBootstrap { } else { let tlsConfiguration = configuration.tlsConfiguration ?? TLSConfiguration.forClient() let sslContext = try NIOSSLContext(configuration: tlsConfiguration) - let hostname = (!requiresTLS || host.isIPAddress) ? nil : host + let hostname = (!requiresTLS || host.isIPAddress || host.isEmpty) ? nil : host let tlsProvider = try NIOSSLClientTLSProvider(context: sslContext, serverHostname: hostname) return NIOClientTCPBootstrap(self, tls: tlsProvider) } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index d1e109844..96d0d3fa4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -234,7 +234,6 @@ internal final class HTTPBin { self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .childChannelOption(ChannelOptions.socket(IPPROTO_TCP, TCP_NODELAY), value: 1) .serverChannelInitializer { channel in channel.pipeline.addHandler(activeConnCounterHandler) }.childChannelInitializer { channel in diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index f13a4ca5e..b9a4af3e0 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -90,6 +90,8 @@ extension HTTPClientTests { ("testMakeSecondRequestWhilstFirstIsOngoing", testMakeSecondRequestWhilstFirstIsOngoing), ("testUDSBasic", testUDSBasic), ("testUDSSocketAndPath", testUDSSocketAndPath), + ("testHTTPPlusUNIX", testHTTPPlusUNIX), + ("testHTTPSPlusUNIX", testHTTPSPlusUNIX), ("testUseExistingConnectionOnDifferentEL", testUseExistingConnectionOnDifferentEL), ("testWeRecoverFromServerThatClosesTheConnectionOnUs", testWeRecoverFromServerThatClosesTheConnectionOnUs), ("testPoolClosesIdleConnections", testPoolClosesIdleConnections), diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index ffb37bf77..5e1a403a1 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -1325,6 +1325,43 @@ class HTTPClientTests: XCTestCase { }) } + func testHTTPPlusUNIX() { + // Here, we're testing a URL where the UNIX domain socket is encoded as the host name + XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(string: "http+unix://\(path.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/echo-uri"), + let request = try? Request(url: target) else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertNoThrow(XCTAssertEqual(["/echo-uri"[...]], + try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"])) + }) + } + + func testHTTPSPlusUNIX() { + // Here, we're testing a URL where the UNIX domain socket is encoded as the host name + XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(ssl: true, bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none)) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(string: "https+unix://\(path.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/echo-uri"), + let request = try? Request(url: target) else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertNoThrow(XCTAssertEqual(["/echo-uri"[...]], + try localClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"])) + }) + } + func testUseExistingConnectionOnDifferentEL() throws { let threadCount = 16 let elg = getDefaultEventLoopGroup(numberOfThreads: threadCount)