Skip to content

Commit 070c1e5

Browse files
authored
cpool: don't reuse connection if we sent close (#225)
Motivation: Previously, we'd only use the server's connection header to determine if we should close the connection or not. That's wrong because if we set `connection: close` ourselves, we must not reuse again. Modification: Set `TaskHandler.closing = false` if we send a close header. Result: More HTTP correctness.
1 parent cb9fd61 commit 070c1e5

File tree

5 files changed

+197
-39
lines changed

5 files changed

+197
-39
lines changed

Sources/AsyncHTTPClient/HTTPHandler.swift

+15-1
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,13 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
604604
var state: State = .idle
605605
var pendingRead = false
606606
var mayRead = true
607-
var closing = false
607+
var closing = false {
608+
didSet {
609+
assert(self.closing || !oldValue,
610+
"BUG in AsyncHTTPClient: TaskHandler.closing went from true (no conn reuse) to true (do reuse).")
611+
}
612+
}
613+
608614
let kind: HTTPClient.Request.Kind
609615

610616
init(task: HTTPClient.Task<Delegate.Response>,
@@ -736,6 +742,14 @@ extension TaskHandler: ChannelDuplexHandler {
736742

737743
head.headers = headers
738744

745+
if head.headers[canonicalForm: "connection"].map({ $0.lowercased() }).contains("close") {
746+
self.closing = true
747+
}
748+
// This assert can go away when (if ever!) the above `if` correctly handles other HTTP versions. For example
749+
// in HTTP/1.0, we need to treat the absence of a 'connection: keep-alive' as a close too.
750+
assert(head.version == HTTPVersion(major: 1, minor: 1),
751+
"Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")
752+
739753
context.write(wrapOutboundOut(.head(head))).map {
740754
self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead)
741755
}.flatMap {

Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+5-4
Original file line numberDiff line numberDiff line change
@@ -139,11 +139,10 @@ class HTTPClientInternalTests: XCTestCase {
139139
}
140140

141141
let upload = try! httpClient.post(url: "http://localhost:\(httpBin.port)/post", body: body).wait()
142-
let bytes = upload.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) }
143-
let data = try! JSONDecoder().decode(RequestInfo.self, from: bytes!)
142+
let data = upload.body.flatMap { try? JSONDecoder().decode(RequestInfo.self, from: $0) }
144143

145144
XCTAssertEqual(.ok, upload.status)
146-
XCTAssertEqual("id: 0id: 1id: 2id: 3id: 4id: 5id: 6id: 7id: 8id: 9", data.data)
145+
XCTAssertEqual("id: 0id: 1id: 2id: 3id: 4id: 5id: 6id: 7id: 8id: 9", data?.data)
147146
}
148147

149148
func testProxyStreamingFailure() throws {
@@ -466,7 +465,9 @@ class HTTPClientInternalTests: XCTestCase {
466465
XCTAssertNoThrow(try httpBin.shutdown())
467466
}
468467

469-
let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get", method: .GET, headers: ["Connection": "close"], body: nil)
468+
let req = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get",
469+
method: .GET,
470+
headers: ["X-Send-Back-Header-Connection": "close"], body: nil)
470471
_ = try! httpClient.execute(request: req).wait()
471472
let el = httpClient.eventLoopGroup.next()
472473
try! el.scheduleTask(in: .milliseconds(500)) {

Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift

+43-23
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler {
339339
}
340340

341341
internal struct HTTPResponseBuilder {
342-
let head: HTTPResponseHead
342+
var head: HTTPResponseHead
343343
var body: ByteBuffer?
344344

345345
init(_ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) {
@@ -357,8 +357,13 @@ internal struct HTTPResponseBuilder {
357357
}
358358
}
359359

360+
let globalRequestCounter = NIOAtomic<Int>.makeAtomic(value: 0)
361+
let globalConnectionCounter = NIOAtomic<Int>.makeAtomic(value: 0)
362+
360363
internal struct RequestInfo: Codable {
361-
let data: String
364+
var data: String
365+
var requestNumber: Int
366+
var connectionNumber: Int
362367
}
363368

364369
internal final class HttpBinHandler: ChannelInboundHandler {
@@ -367,16 +372,19 @@ internal final class HttpBinHandler: ChannelInboundHandler {
367372

368373
let channelPromise: EventLoopPromise<Channel>?
369374
var resps = CircularBuffer<HTTPResponseBuilder>()
370-
var closeAfterResponse = false
375+
var responseHeaders = HTTPHeaders()
371376
var delay: TimeAmount = .seconds(0)
372377
let creationDate = Date()
373378
let maxChannelAge: TimeAmount?
374379
var shouldClose = false
375380
var isServingRequest = false
381+
let myConnectionNumber: Int
382+
var currentRequestNumber: Int = -1
376383

377384
init(channelPromise: EventLoopPromise<Channel>? = nil, maxChannelAge: TimeAmount? = nil) {
378385
self.channelPromise = channelPromise
379386
self.maxChannelAge = maxChannelAge
387+
self.myConnectionNumber = globalConnectionCounter.add(1)
380388
}
381389

382390
func handlerAdded(context: ChannelHandlerContext) {
@@ -402,27 +410,31 @@ internal final class HttpBinHandler: ChannelInboundHandler {
402410
self.delay = .nanoseconds(0)
403411
}
404412

405-
if let connection = head.headers["Connection"].first {
406-
self.closeAfterResponse = (connection == "close")
407-
} else {
408-
self.closeAfterResponse = false
413+
for header in head.headers {
414+
let needle = "x-send-back-header-"
415+
if header.name.lowercased().starts(with: needle) {
416+
self.responseHeaders.add(name: String(header.name.dropFirst(needle.count)),
417+
value: header.value)
418+
}
409419
}
410420
}
411421

412422
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
413423
self.isServingRequest = true
414424
switch self.unwrapInboundIn(data) {
415425
case .head(let req):
426+
self.responseHeaders = HTTPHeaders()
427+
self.currentRequestNumber = globalRequestCounter.add(1)
416428
self.parseAndSetOptions(from: req)
417429
let urlComponents = URLComponents(string: req.uri)!
418430
switch urlComponents.percentEncodedPath {
419431
case "/":
420-
var headers = HTTPHeaders()
432+
var headers = self.responseHeaders
421433
headers.add(name: "X-Is-This-Slash", value: "Yes")
422434
self.resps.append(HTTPResponseBuilder(status: .ok, headers: headers))
423435
return
424436
case "/echo-uri":
425-
var headers = HTTPHeaders()
437+
var headers = self.responseHeaders
426438
headers.add(name: "X-Calling-URI", value: req.uri)
427439
self.resps.append(HTTPResponseBuilder(status: .ok, headers: headers))
428440
return
@@ -436,6 +448,13 @@ internal final class HttpBinHandler: ChannelInboundHandler {
436448
}
437449
self.resps.append(HTTPResponseBuilder(status: .ok))
438450
return
451+
case "/stats":
452+
var body = context.channel.allocator.buffer(capacity: 1)
453+
body.writeString("Just some stats mate.")
454+
var builder = HTTPResponseBuilder(status: .ok)
455+
builder.add(body)
456+
457+
self.resps.append(builder)
439458
case "/post":
440459
if req.method != .POST {
441460
self.resps.append(HTTPResponseBuilder(status: .methodNotAllowed))
@@ -444,29 +463,29 @@ internal final class HttpBinHandler: ChannelInboundHandler {
444463
self.resps.append(HTTPResponseBuilder(status: .ok))
445464
return
446465
case "/redirect/302":
447-
var headers = HTTPHeaders()
466+
var headers = self.responseHeaders
448467
headers.add(name: "location", value: "/ok")
449468
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
450469
return
451470
case "/redirect/https":
452471
let port = self.value(for: "port", from: urlComponents.query!)
453-
var headers = HTTPHeaders()
472+
var headers = self.responseHeaders
454473
headers.add(name: "Location", value: "https://localhost:\(port)/ok")
455474
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
456475
return
457476
case "/redirect/loopback":
458477
let port = self.value(for: "port", from: urlComponents.query!)
459-
var headers = HTTPHeaders()
478+
var headers = self.responseHeaders
460479
headers.add(name: "Location", value: "http://127.0.0.1:\(port)/echohostheader")
461480
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
462481
return
463482
case "/redirect/infinite1":
464-
var headers = HTTPHeaders()
483+
var headers = self.responseHeaders
465484
headers.add(name: "Location", value: "/redirect/infinite2")
466485
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
467486
return
468487
case "/redirect/infinite2":
469-
var headers = HTTPHeaders()
488+
var headers = self.responseHeaders
470489
headers.add(name: "Location", value: "/redirect/infinite1")
471490
self.resps.append(HTTPResponseBuilder(status: .found, headers: headers))
472491
return
@@ -528,15 +547,15 @@ internal final class HttpBinHandler: ChannelInboundHandler {
528547
if self.resps.isEmpty {
529548
return
530549
}
531-
let response = self.resps.removeFirst()
550+
var response = self.resps.removeFirst()
551+
response.head.headers.add(contentsOf: self.responseHeaders)
532552
context.write(wrapOutboundOut(.head(response.head)), promise: nil)
533-
if var body = response.body {
534-
let data = body.readData(length: body.readableBytes)!
535-
let serialized = try! JSONEncoder().encode(RequestInfo(data: String(decoding: data,
536-
as: Unicode.UTF8.self)))
537-
538-
var responseBody = context.channel.allocator.buffer(capacity: serialized.count)
539-
responseBody.writeBytes(serialized)
553+
if let body = response.body {
554+
let requestInfo = RequestInfo(data: String(buffer: body),
555+
requestNumber: self.currentRequestNumber,
556+
connectionNumber: self.myConnectionNumber)
557+
let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo,
558+
allocator: context.channel.allocator)
540559
context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil)
541560
}
542561
context.eventLoop.scheduleTask(in: self.delay) {
@@ -549,7 +568,8 @@ internal final class HttpBinHandler: ChannelInboundHandler {
549568
self.isServingRequest = false
550569
switch result {
551570
case .success:
552-
if self.closeAfterResponse || self.shouldClose {
571+
if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") ||
572+
self.shouldClose {
553573
context.close(promise: nil)
554574
}
555575
case .failure(let error):

Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

+4
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ extension HTTPClientTests {
9999
("testValidationErrorsAreSurfaced", testValidationErrorsAreSurfaced),
100100
("testUploadsReallyStream", testUploadsReallyStream),
101101
("testUploadStreamingCallinToleratedFromOtsideEL", testUploadStreamingCallinToleratedFromOtsideEL),
102+
("testWeHandleUsSendingACloseHeaderCorrectly", testWeHandleUsSendingACloseHeaderCorrectly),
103+
("testWeHandleUsReceivingACloseHeaderCorrectly", testWeHandleUsReceivingACloseHeaderCorrectly),
104+
("testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly),
105+
("testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly", testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly),
102106
]
103107
}
104108
}

0 commit comments

Comments
 (0)