Skip to content

Commit 606ab0e

Browse files
authored
check body length (swift-server#255)
1 parent f2aef45 commit 606ab0e

8 files changed

+131
-34
lines changed

Diff for: Sources/AsyncHTTPClient/HTTPClient.swift

+6-3
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
925925
case uncleanShutdown
926926
case traceRequestWithBody
927927
case invalidHeaderFieldNames([String])
928+
case bodyLengthMismatch
928929
}
929930

930931
private var code: Code
@@ -969,10 +970,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible {
969970
public static let redirectLimitReached = HTTPClientError(code: .redirectLimitReached)
970971
/// Redirect Cycle detected.
971972
public static let redirectCycleDetected = HTTPClientError(code: .redirectCycleDetected)
972-
/// Unclean shutdown
973+
/// Unclean shutdown.
973974
public static let uncleanShutdown = HTTPClientError(code: .uncleanShutdown)
974-
/// A body was sent in a request with method TRACE
975+
/// A body was sent in a request with method TRACE.
975976
public static let traceRequestWithBody = HTTPClientError(code: .traceRequestWithBody)
976-
/// Header field names contain invalid characters
977+
/// Header field names contain invalid characters.
977978
public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) }
979+
/// Body length is not equal to `Content-Length`.
980+
public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch)
978981
}

Diff for: Sources/AsyncHTTPClient/HTTPHandler.swift

+29-14
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
641641
case head
642642
case redirected(HTTPResponseHead, URL)
643643
case body
644-
case end
644+
case endOrError
645645
}
646646

647647
let task: HTTPClient.Task<Delegate.Response>
@@ -651,6 +651,8 @@ internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChann
651651
let logger: Logger // We are okay to store the logger here because a TaskHandler is just for one request.
652652

653653
var state: State = .idle
654+
var expectedBodyLength: Int?
655+
var actualBodyLength: Int = 0
654656
var pendingRead = false
655657
var mayRead = true
656658
var closing = false {
@@ -785,7 +787,7 @@ extension TaskHandler: ChannelDuplexHandler {
785787
} catch {
786788
promise?.fail(error)
787789
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
788-
self.state = .end
790+
self.state = .endOrError
789791
return
790792
}
791793

@@ -799,12 +801,23 @@ extension TaskHandler: ChannelDuplexHandler {
799801
assert(head.version == HTTPVersion(major: 1, minor: 1),
800802
"Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")
801803

804+
let contentLengths = head.headers[canonicalForm: "content-length"]
805+
assert(contentLengths.count <= 1)
806+
807+
self.expectedBodyLength = contentLengths.first.flatMap { Int($0) }
808+
802809
context.write(wrapOutboundOut(.head(head))).map {
803810
self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead)
804811
}.flatMap {
805812
self.writeBody(request: request, context: context)
806813
}.flatMap {
807814
context.eventLoop.assertInEventLoop()
815+
if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
816+
self.state = .endOrError
817+
let error = HTTPClientError.bodyLengthMismatch
818+
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
819+
return context.eventLoop.makeFailedFuture(error)
820+
}
808821
return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
809822
}.map {
810823
context.eventLoop.assertInEventLoop()
@@ -813,10 +826,10 @@ extension TaskHandler: ChannelDuplexHandler {
813826
}.flatMapErrorThrowing { error in
814827
context.eventLoop.assertInEventLoop()
815828
switch self.state {
816-
case .end:
829+
case .endOrError:
817830
break
818831
default:
819-
self.state = .end
832+
self.state = .endOrError
820833
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
821834
}
822835
throw error
@@ -833,9 +846,11 @@ extension TaskHandler: ChannelDuplexHandler {
833846
let promise = self.task.eventLoop.makePromise(of: Void.self)
834847
// All writes have to be switched to the channel EL if channel and task ELs differ
835848
if context.eventLoop.inEventLoop {
849+
self.actualBodyLength += part.readableBytes
836850
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
837851
} else {
838852
context.eventLoop.execute {
853+
self.actualBodyLength += part.readableBytes
839854
context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
840855
}
841856
}
@@ -898,12 +913,12 @@ extension TaskHandler: ChannelDuplexHandler {
898913
case .end:
899914
switch self.state {
900915
case .redirected(let head, let redirectURL):
901-
self.state = .end
916+
self.state = .endOrError
902917
self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {
903918
self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise)
904919
}
905920
default:
906-
self.state = .end
921+
self.state = .endOrError
907922
self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest)
908923
}
909924
}
@@ -918,14 +933,14 @@ extension TaskHandler: ChannelDuplexHandler {
918933
context.read()
919934
}
920935
case .failure(let error):
921-
self.state = .end
936+
self.state = .endOrError
922937
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
923938
}
924939
}
925940

926941
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
927942
if (event as? IdleStateHandler.IdleStateEvent) == .read {
928-
self.state = .end
943+
self.state = .endOrError
929944
let error = HTTPClientError.readTimeout
930945
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
931946
} else {
@@ -935,7 +950,7 @@ extension TaskHandler: ChannelDuplexHandler {
935950

936951
func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {
937952
if (event as? TaskCancelEvent) != nil {
938-
self.state = .end
953+
self.state = .endOrError
939954
let error = HTTPClientError.cancelled
940955
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
941956
promise?.succeed(())
@@ -946,10 +961,10 @@ extension TaskHandler: ChannelDuplexHandler {
946961

947962
func channelInactive(context: ChannelHandlerContext) {
948963
switch self.state {
949-
case .end:
964+
case .endOrError:
950965
break
951966
case .body, .head, .idle, .redirected, .sent:
952-
self.state = .end
967+
self.state = .endOrError
953968
let error = HTTPClientError.remoteConnectionClosed
954969
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
955970
}
@@ -960,7 +975,7 @@ extension TaskHandler: ChannelDuplexHandler {
960975
switch error {
961976
case NIOSSLError.uncleanShutdown:
962977
switch self.state {
963-
case .end:
978+
case .endOrError:
964979
/// Some HTTP Servers can 'forget' to respond with CloseNotify when client is closing connection,
965980
/// this could lead to incomplete SSL shutdown. But since request is already processed, we can ignore this error.
966981
break
@@ -969,11 +984,11 @@ extension TaskHandler: ChannelDuplexHandler {
969984
/// We can also ignore this error like `.end`.
970985
break
971986
default:
972-
self.state = .end
987+
self.state = .endOrError
973988
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
974989
}
975990
default:
976-
self.state = .end
991+
self.state = .endOrError
977992
self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
978993
}
979994
}

Diff for: Tests/AsyncHTTPClientTests/HTTPClientInternalTests.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -949,8 +949,8 @@ class HTTPClientInternalTests: XCTestCase {
949949

950950
defer {
951951
XCTAssertNoThrow(try client.syncShutdown())
952-
XCTAssertNoThrow(try elg.syncShutdownGracefully())
953952
XCTAssertNoThrow(try httpBin.shutdown())
953+
XCTAssertNoThrow(try elg.syncShutdownGracefully())
954954
}
955955

956956
let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)//get")

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift

+20-12
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ internal final class HTTPBin {
188188
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
189189
let serverChannel: Channel
190190
let isShutdown: NIOAtomic<Bool> = .makeAtomic(value: false)
191+
var connections: NIOAtomic<Int>
191192
var connectionCount: NIOAtomic<Int> = .makeAtomic(value: 0)
192193
private let activeConnCounterHandler: CountActiveConnectionsHandler
193194
var activeConnections: Int {
@@ -233,6 +234,9 @@ internal final class HTTPBin {
233234
let activeConnCounterHandler = CountActiveConnectionsHandler()
234235
self.activeConnCounterHandler = activeConnCounterHandler
235236

237+
let connections = NIOAtomic.makeAtomic(value: 0)
238+
self.connections = connections
239+
236240
self.serverChannel = try! ServerBootstrap(group: self.group)
237241
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
238242
.serverChannelInitializer { channel in
@@ -261,10 +265,10 @@ internal final class HTTPBin {
261265
}.flatMap {
262266
if ssl {
263267
return HTTPBin.configureTLS(channel: channel).flatMap {
264-
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge))
268+
channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1)))
265269
}
266270
} else {
267-
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge))
271+
return channel.pipeline.addHandler(HttpBinHandler(channelPromise: channelPromise, maxChannelAge: maxChannelAge, connectionId: connections.add(1)))
268272
}
269273
}
270274
}
@@ -357,9 +361,6 @@ internal struct HTTPResponseBuilder {
357361
}
358362
}
359363

360-
let globalRequestCounter = NIOAtomic<Int>.makeAtomic(value: 0)
361-
let globalConnectionCounter = NIOAtomic<Int>.makeAtomic(value: 0)
362-
363364
internal struct RequestInfo: Codable {
364365
var data: String
365366
var requestNumber: Int
@@ -378,13 +379,13 @@ internal final class HttpBinHandler: ChannelInboundHandler {
378379
let maxChannelAge: TimeAmount?
379380
var shouldClose = false
380381
var isServingRequest = false
381-
let myConnectionNumber: Int
382-
var currentRequestNumber: Int = -1
382+
let connectionId: Int
383+
var requestId: Int = 0
383384

384-
init(channelPromise: EventLoopPromise<Channel>? = nil, maxChannelAge: TimeAmount? = nil) {
385+
init(channelPromise: EventLoopPromise<Channel>? = nil, maxChannelAge: TimeAmount? = nil, connectionId: Int) {
385386
self.channelPromise = channelPromise
386387
self.maxChannelAge = maxChannelAge
387-
self.myConnectionNumber = globalConnectionCounter.add(1)
388+
self.connectionId = connectionId
388389
}
389390

390391
func handlerAdded(context: ChannelHandlerContext) {
@@ -424,7 +425,7 @@ internal final class HttpBinHandler: ChannelInboundHandler {
424425
switch self.unwrapInboundIn(data) {
425426
case .head(let req):
426427
self.responseHeaders = HTTPHeaders()
427-
self.currentRequestNumber = globalRequestCounter.add(1)
428+
self.requestId += 1
428429
self.parseAndSetOptions(from: req)
429430
let urlComponents = URLComponents(string: req.uri)!
430431
switch urlComponents.percentEncodedPath {
@@ -552,8 +553,15 @@ internal final class HttpBinHandler: ChannelInboundHandler {
552553
context.write(wrapOutboundOut(.head(response.head)), promise: nil)
553554
if let body = response.body {
554555
let requestInfo = RequestInfo(data: String(buffer: body),
555-
requestNumber: self.currentRequestNumber,
556-
connectionNumber: self.myConnectionNumber)
556+
requestNumber: self.requestId,
557+
connectionNumber: self.connectionId)
558+
let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo,
559+
allocator: context.channel.allocator)
560+
context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil)
561+
} else {
562+
let requestInfo = RequestInfo(data: "",
563+
requestNumber: self.requestId,
564+
connectionNumber: self.connectionId)
557565
let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo,
558566
allocator: context.channel.allocator)
559567
context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil)

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

+2
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ extension HTTPClientTests {
110110
("testAllMethodsLog", testAllMethodsLog),
111111
("testClosingIdleConnectionsInPoolLogsInTheBackground", testClosingIdleConnectionsInPoolLogsInTheBackground),
112112
("testDelegateCallinsTolerateRandomEL", testDelegateCallinsTolerateRandomEL),
113+
("testContentLengthTooLongFails", testContentLengthTooLongFails),
114+
("testContentLengthTooShortFails", testContentLengthTooShortFails),
113115
]
114116
}
115117
}

Diff for: Tests/AsyncHTTPClientTests/HTTPClientTests.swift

+62-4
Original file line numberDiff line numberDiff line change
@@ -1713,7 +1713,8 @@ class HTTPClientTests: XCTestCase {
17131713

17141714
// req 1 and 2 cannot share the same connection (close header)
17151715
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
1716-
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
1716+
XCTAssertEqual(stats1.requestNumber, 1)
1717+
XCTAssertEqual(stats2.requestNumber, 1)
17171718

17181719
// req 2 and 3 should share the same connection (keep-alive is default)
17191720
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
@@ -1742,7 +1743,8 @@ class HTTPClientTests: XCTestCase {
17421743

17431744
// req 1 and 2 cannot share the same connection (close header)
17441745
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
1745-
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
1746+
XCTAssertEqual(stats1.requestNumber, 1)
1747+
XCTAssertEqual(stats2.requestNumber, 1)
17461748

17471749
// req 2 and 3 should share the same connection (keep-alive is default)
17481750
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
@@ -1773,7 +1775,7 @@ class HTTPClientTests: XCTestCase {
17731775

17741776
// req 1 and 2 cannot share the same connection (close header)
17751777
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
1776-
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
1778+
XCTAssertEqual(stats2.requestNumber, 1)
17771779

17781780
// req 2 and 3 should share the same connection (keep-alive is default)
17791781
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
@@ -1805,7 +1807,7 @@ class HTTPClientTests: XCTestCase {
18051807

18061808
// req 1 and 2 cannot share the same connection (close header)
18071809
XCTAssertEqual(stats1.connectionNumber + 1, stats2.connectionNumber)
1808-
XCTAssertEqual(stats1.requestNumber + 1, stats2.requestNumber)
1810+
XCTAssertEqual(stats2.requestNumber, 1)
18091811

18101812
// req 2 and 3 should share the same connection (keep-alive is default)
18111813
XCTAssertEqual(stats2.requestNumber + 1, stats3.requestNumber)
@@ -2036,6 +2038,7 @@ class HTTPClientTests: XCTestCase {
20362038
defer {
20372039
XCTAssertNoThrow(try httpClient.syncShutdown())
20382040
XCTAssertNoThrow(try httpServer.stop())
2041+
XCTAssertNoThrow(try elg.syncShutdownGracefully())
20392042
}
20402043

20412044
let delegate = TestDelegate(eventLoop: second)
@@ -2051,4 +2054,59 @@ class HTTPClientTests: XCTestCase {
20512054

20522055
XCTAssertNoThrow(try future.wait())
20532056
}
2057+
2058+
func testContentLengthTooLongFails() throws {
2059+
let url = self.defaultHTTPBinURLPrefix + "/post"
2060+
XCTAssertThrowsError(
2061+
try self.defaultClient.execute(request:
2062+
Request(url: url,
2063+
body: .stream(length: 10) { streamWriter in
2064+
let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self)
2065+
DispatchQueue(label: "content-length-test").async {
2066+
streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise)
2067+
}
2068+
return promise.futureResult
2069+
})).wait()) { error in
2070+
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch)
2071+
}
2072+
// Quickly try another request and check that it works.
2073+
let response = try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()
2074+
guard var body = response.body else {
2075+
XCTFail("Body missing: \(response)")
2076+
return
2077+
}
2078+
guard let info = try body.readJSONDecodable(RequestInfo.self, length: body.readableBytes) else {
2079+
XCTFail("Cannot parse body: \(body.readableBytesView.map { $0 })")
2080+
return
2081+
}
2082+
XCTAssertEqual(info.connectionNumber, 1)
2083+
XCTAssertEqual(info.requestNumber, 1)
2084+
}
2085+
2086+
// currently gets stuck because of #250 the server just never replies
2087+
func testContentLengthTooShortFails() throws {
2088+
let url = self.defaultHTTPBinURLPrefix + "/post"
2089+
let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n"
2090+
XCTAssertThrowsError(
2091+
try self.defaultClient.execute(request:
2092+
Request(url: url,
2093+
body: .stream(length: 1) { streamWriter in
2094+
streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong)))
2095+
})).wait()) { error in
2096+
XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch)
2097+
}
2098+
// Quickly try another request and check that it works. If we by accident wrote some extra bytes into the
2099+
// stream (and reuse the connection) that could cause problems.
2100+
let response = try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "get").wait()
2101+
guard var body = response.body else {
2102+
XCTFail("Body missing: \(response)")
2103+
return
2104+
}
2105+
guard let info = try body.readJSONDecodable(RequestInfo.self, length: body.readableBytes) else {
2106+
XCTFail("Cannot parse body: \(body.readableBytesView.map { $0 })")
2107+
return
2108+
}
2109+
XCTAssertEqual(info.connectionNumber, 1)
2110+
XCTAssertEqual(info.requestNumber, 1)
2111+
}
20542112
}

Diff for: Tests/AsyncHTTPClientTests/RequestValidationTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ extension RequestValidationTests {
3333
("testGET_HEAD_DELETE_CONNECTRequestCanHaveBody", testGET_HEAD_DELETE_CONNECTRequestCanHaveBody),
3434
("testInvalidHeaderFieldNames", testInvalidHeaderFieldNames),
3535
("testValidHeaderFieldNames", testValidHeaderFieldNames),
36+
("testMultipleContentLengthOnNilStreamLength", testMultipleContentLengthOnNilStreamLength),
3637
]
3738
}
3839
}

0 commit comments

Comments
 (0)