diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index b644c1fa8..bb6779a91 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -765,12 +765,32 @@ extension TaskHandler: ChannelDuplexHandler { return context.eventLoop.makeSucceededFuture(()) } - return body.stream(HTTPClient.Body.StreamWriter { part in - context.eventLoop.assertInEventLoop() - return context.writeAndFlush(self.wrapOutboundOut(.body(part))).map { - self.callOutToDelegateFireAndForget(value: part, self.delegate.didSendRequestPart) + func doIt() -> EventLoopFuture { + return body.stream(HTTPClient.Body.StreamWriter { part in + let promise = self.task.eventLoop.makePromise(of: Void.self) + // All writes have to be switched to the channel EL if channel and task ELs differ + if context.eventLoop.inEventLoop { + context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) + } else { + context.eventLoop.execute { + context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise) + } + } + + return promise.futureResult.map { + self.callOutToDelegateFireAndForget(value: part, self.delegate.didSendRequestPart) + } + }) + } + + // Callout to the user to start body streaming should be on task EL + if self.task.eventLoop.inEventLoop { + return doIt() + } else { + return self.task.eventLoop.flatSubmit { + doIt() } - }) + } } public func read(context: ChannelHandlerContext) { diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift index 7dcaaf200..8cb1f5b7f 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift @@ -98,6 +98,8 @@ extension HTTPClientTests { ("testAsyncShutdown", testAsyncShutdown), ("testValidationErrorsAreSurfaced", testValidationErrorsAreSurfaced), ("testUploadsReallyStream", testUploadsReallyStream), + ("testUploadStreamingCallinToleratedFromOtsideEL", testUploadStreamingCallinToleratedFromOtsideEL), + ("testUploadStreamingIsCalledOnTaskEL", testUploadStreamingIsCalledOnTaskEL), ] } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 6297767fa..0d045d551 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -1705,6 +1705,7 @@ class HTTPClientTests: XCTestCase { private let bodyPromises: [EventLoopPromise] private let endPromise: EventLoopPromise private var bodyPartsSeenSoFar = 0 + private var atEnd = false init(headPromise: EventLoopPromise, bodyPromises: [EventLoopPromise], @@ -1727,10 +1728,14 @@ class HTTPClientTests: XCTestCase { context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), promise: nil) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: self.endPromise) + self.atEnd = true } } func handlerRemoved(context: ChannelHandlerContext) { + guard !self.atEnd else { + return + } struct NotFulfilledError: Error {} self.headPromise.fail(NotFulfilledError()) @@ -1753,10 +1758,7 @@ class HTTPClientTests: XCTestCase { let bodyPromises = (0..<16).map { _ in group.next().makePromise(of: ByteBuffer.self) } let endPromise = group.next().makePromise(of: Void.self) let sentOffAllBodyPartsPromise = group.next().makePromise(of: Void.self) - // Because of https://github.com/swift-server/async-http-client/issues/200 we also need to pull off a terrible - // hack and get the internal EventLoop out :(. Once the bug is fixed, this promise should only get the - // StreamWriter. - let streamWriterPromise = group.next().makePromise(of: (EventLoop, HTTPClient.Body.StreamWriter).self) + let streamWriterPromise = group.next().makePromise(of: HTTPClient.Body.StreamWriter.self) func makeServer() -> Channel? { return try? ServerBootstrap(group: group) @@ -1781,12 +1783,7 @@ class HTTPClientTests: XCTestCase { method: .POST, headers: ["transfer-encoding": "chunked"], body: .stream { streamWriter in - // Due to https://github.com/swift-server/async-http-client/issues/200 - // we also need to pull off a terrible hack and get the internal - // EventLoop out :(. Once the bug is fixed, this promise should only get - // the StreamWriter. - let currentEL = MultiThreadedEventLoopGroup.currentEventLoop! // HACK!! - streamWriterPromise.succeed((currentEL, streamWriter)) + streamWriterPromise.succeed(streamWriter) return sentOffAllBodyPartsPromise.futureResult }) } @@ -1811,13 +1808,69 @@ class HTTPClientTests: XCTestCase { buffer.clear() buffer.writeString(String(bodyChunkNumber, radix: 16)) XCTAssertEqual(1, buffer.readableBytes) - XCTAssertNoThrow(try streamWriter.0.flatSubmit { - streamWriter.1.write(.byteBuffer(buffer)) - }.wait()) + XCTAssertNoThrow(try streamWriter.write(.byteBuffer(buffer)).wait()) XCTAssertNoThrow(XCTAssertEqual(buffer, try bodyPromises[bodyChunkNumber].futureResult.wait())) } sentOffAllBodyPartsPromise.succeed(()) XCTAssertNoThrow(try endPromise.futureResult.wait()) XCTAssertNoThrow(try runningRequest.wait()) } + + func testUploadStreamingCallinToleratedFromOtsideEL() throws { + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .POST, body: .stream(length: 4) { writer in + let promise = httpClient.eventLoopGroup.next().makePromise(of: Void.self) + // We have to toleare callins from any thread + DispatchQueue(label: "upload-streaming").async { + writer.write(.byteBuffer(ByteBuffer.of(string: "1234"))).whenComplete { _ in + promise.succeed(()) + } + } + return promise.futureResult + }) + XCTAssertNoThrow(try httpClient.execute(request: request).wait()) + } + + func testUploadStreamingIsCalledOnTaskEL() throws { + let group = getDefaultEventLoopGroup(numberOfThreads: 4) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + let httpBin = HTTPBin() + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(group)) + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + XCTAssertNoThrow(try httpBin.shutdown()) + } + + let el1 = group.next() + let el2 = group.next() + XCTAssertFalse(el1 === el2) + + do { + // Pre-populate pool with a connection on a different EL + let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET) + XCTAssertNoThrow(try httpClient.execute(request: request, delegate: ResponseAccumulator(request: request), eventLoop: .delegateAndChannel(on: el2)).wait()) + } + + let body: HTTPClient.Body = .stream(length: 8) { writer in + XCTAssert(el1.inEventLoop) + let buffer = ByteBuffer.of(string: "1234") + return writer.write(.byteBuffer(buffer)).flatMap { + XCTAssert(el1.inEventLoop) + let buffer = ByteBuffer.of(string: "4321") + return writer.write(.byteBuffer(buffer)) + } + } + let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST, body: body) + let response = httpClient.execute(request: request, delegate: ResponseAccumulator(request: request), eventLoop: .delegate(on: el1)) + XCTAssertNoThrow(try response.wait()) + } }