Skip to content

Commit aac0a82

Browse files
More checks for task cancellation and tests (#44)
### Motivation In our fallback, buffered implementation, we did not use a task cancellation handler so were not proactively cancelling the URLSession task when the Swift concurrency task was cancelled. Additionally, while we _did_ have a task cancellation handler in the streaming implementation, so the URLSession task would be cancelled, we were not actively checking for task cancellation as often as we could. ### Modifications - Added more cooperative task cancellation. - Added tests for both implementations that when the parent task for the client request is cancelled that we get something sensible. Note that in some cases, the request will succeed. In the cases where the request fails, it will surface as a `ClientError` to the user where the `underlyingError` is either `Swift.CancellationError` or `URLError` with `code == .cancelled`. ### Result More cooperative task and URLSession task cancellation and more thorough tests. ### Test Plan Added unit tests.
1 parent 144464e commit aac0a82

File tree

4 files changed

+282
-19
lines changed

4 files changed

+282
-19
lines changed

Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift

+3
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import Foundation
3232
task = dataTask(with: urlRequest)
3333
}
3434
return try await withTaskCancellationHandler {
35+
try Task.checkCancellation()
3536
let delegate = BidirectionalStreamingURLSessionDelegate(
3637
requestBody: requestBody,
3738
requestStreamBufferSize: requestStreamBufferSize,
@@ -47,8 +48,10 @@ import Foundation
4748
length: .init(from: response),
4849
iterationBehavior: .single
4950
)
51+
try Task.checkCancellation()
5052
return (try HTTPResponse(response), responseBody)
5153
} onCancel: {
54+
debug("Concurrency task cancelled, cancelling URLSession task.")
5255
task.cancel()
5356
}
5457
}

Sources/OpenAPIURLSession/URLSessionTransport.swift

+35-15
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import class Foundation.FileHandle
2424
#if canImport(FoundationNetworking)
2525
@preconcurrency import struct FoundationNetworking.URLRequest
2626
import class FoundationNetworking.URLSession
27+
import class FoundationNetworking.URLSessionTask
2728
import class FoundationNetworking.URLResponse
2829
import class FoundationNetworking.HTTPURLResponse
2930
#endif
@@ -243,31 +244,50 @@ extension URLSession {
243244
func bufferedRequest(for request: HTTPRequest, baseURL: URL, requestBody: HTTPBody?) async throws -> (
244245
HTTPResponse, HTTPBody?
245246
) {
247+
try Task.checkCancellation()
246248
var urlRequest = try URLRequest(request, baseURL: baseURL)
247249
if let requestBody { urlRequest.httpBody = try await Data(collecting: requestBody, upTo: .max) }
250+
try Task.checkCancellation()
248251

249252
/// Use `dataTask(with:completionHandler:)` here because `data(for:[delegate:]) async` is only available on
250253
/// Darwin platforms newer than our minimum deployment target, and not at all on Linux.
251-
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
252-
continuation in
253-
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
254-
if let error {
255-
continuation.resume(throwing: error)
256-
return
254+
let taskBox: LockedValueBox<URLSessionTask?> = .init(nil)
255+
return try await withTaskCancellationHandler {
256+
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
257+
continuation in
258+
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
259+
if let error {
260+
continuation.resume(throwing: error)
261+
return
262+
}
263+
guard let response else {
264+
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
265+
return
266+
}
267+
continuation.resume(with: .success((response, data)))
257268
}
258-
guard let response else {
259-
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
260-
return
269+
// Swift concurrency task cancelled here.
270+
taskBox.withLockedValue { boxedTask in
271+
guard task.state == .suspended else {
272+
debug("URLSession task cannot be resumed, probably because it was cancelled by onCancel.")
273+
return
274+
}
275+
task.resume()
276+
boxedTask = task
261277
}
262-
continuation.resume(with: .success((response, data)))
263278
}
264-
task.resume()
265-
}
266279

267-
let maybeResponseBody = maybeResponseBodyData.map { data in
268-
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
280+
let maybeResponseBody = maybeResponseBodyData.map { data in
281+
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
282+
}
283+
return (try HTTPResponse(response), maybeResponseBody)
284+
} onCancel: {
285+
taskBox.withLockedValue { boxedTask in
286+
debug("Concurrency task cancelled, cancelling URLSession task.")
287+
boxedTask?.cancel()
288+
boxedTask = nil
289+
}
269290
}
270-
return (try HTTPResponse(response), maybeResponseBody)
271291
}
272292
}
273293

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the SwiftOpenAPIGenerator open source project
4+
//
5+
// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
#if canImport(Darwin)
15+
16+
import Foundation
17+
import HTTPTypes
18+
import NIO
19+
import OpenAPIRuntime
20+
import XCTest
21+
@testable import OpenAPIURLSession
22+
23+
enum CancellationPoint: CaseIterable {
24+
case beforeSendingHead
25+
case beforeSendingRequestBody
26+
case partwayThroughSendingRequestBody
27+
case beforeConsumingResponseBody
28+
case partwayThroughConsumingResponseBody
29+
case afterConsumingResponseBody
30+
}
31+
32+
func testTaskCancelled(_ cancellationPoint: CancellationPoint, transport: URLSessionTransport) async throws {
33+
let requestPath = "/hello/world"
34+
let requestBodyElements = ["Hello,", "world!"]
35+
let requestBodySequence = MockAsyncSequence(elementsToVend: requestBodyElements, gatingProduction: true)
36+
let requestBody = HTTPBody(
37+
requestBodySequence,
38+
length: .known(Int64(requestBodyElements.joined().lengthOfBytes(using: .utf8))),
39+
iterationBehavior: .single
40+
)
41+
42+
let responseBodyMessage = "Hey!"
43+
44+
let taskShouldCancel = XCTestExpectation(description: "Concurrency task cancelled")
45+
let taskCancelled = XCTestExpectation(description: "Concurrency task cancelled")
46+
47+
try await withThrowingTaskGroup(of: Void.self) { group in
48+
let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in
49+
try await connectionChannel.executeThenClose { inbound, outbound in
50+
var requestPartIterator = inbound.makeAsyncIterator()
51+
var accumulatedBody = ByteBuffer()
52+
while let requestPart = try await requestPartIterator.next() {
53+
switch requestPart {
54+
case .head(let head):
55+
XCTAssertEqual(head.uri, requestPath)
56+
XCTAssertEqual(head.method, .POST)
57+
case .body(let buffer): accumulatedBody.writeImmutableBuffer(buffer)
58+
case .end:
59+
switch cancellationPoint {
60+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody,
61+
.afterConsumingResponseBody:
62+
XCTAssertEqual(
63+
String(decoding: accumulatedBody.readableBytesView, as: UTF8.self),
64+
requestBodyElements.joined()
65+
)
66+
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody: break
67+
}
68+
try await outbound.write(.head(.init(version: .http1_1, status: .ok)))
69+
try await outbound.write(.body(ByteBuffer(string: responseBodyMessage)))
70+
try await outbound.write(.end(nil))
71+
}
72+
}
73+
}
74+
}
75+
debug("Server running on 127.0.0.1:\(serverPort)")
76+
77+
let task = Task {
78+
if case .beforeSendingHead = cancellationPoint {
79+
taskShouldCancel.fulfill()
80+
await fulfillment(of: [taskCancelled])
81+
}
82+
debug("Client starting request")
83+
async let (asyncResponse, asyncResponseBody) = try await transport.send(
84+
HTTPRequest(method: .post, scheme: nil, authority: nil, path: requestPath),
85+
body: requestBody,
86+
baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!,
87+
operationID: "unused"
88+
)
89+
90+
if case .beforeSendingRequestBody = cancellationPoint {
91+
taskShouldCancel.fulfill()
92+
await fulfillment(of: [taskCancelled])
93+
}
94+
95+
requestBodySequence.openGate(for: 1)
96+
97+
if case .partwayThroughSendingRequestBody = cancellationPoint {
98+
taskShouldCancel.fulfill()
99+
await fulfillment(of: [taskCancelled])
100+
}
101+
102+
requestBodySequence.openGate()
103+
104+
let (response, maybeResponseBody) = try await (asyncResponse, asyncResponseBody)
105+
106+
debug("Client received response head: \(response)")
107+
XCTAssertEqual(response.status, .ok)
108+
let responseBody = try XCTUnwrap(maybeResponseBody)
109+
110+
if case .beforeConsumingResponseBody = cancellationPoint {
111+
taskShouldCancel.fulfill()
112+
await fulfillment(of: [taskCancelled])
113+
}
114+
115+
var iterator = responseBody.makeAsyncIterator()
116+
117+
_ = try await iterator.next()
118+
119+
if case .partwayThroughConsumingResponseBody = cancellationPoint {
120+
taskShouldCancel.fulfill()
121+
await fulfillment(of: [taskCancelled])
122+
}
123+
124+
while try await iterator.next() != nil {
125+
126+
}
127+
128+
if case .afterConsumingResponseBody = cancellationPoint {
129+
taskShouldCancel.fulfill()
130+
await fulfillment(of: [taskCancelled])
131+
}
132+
133+
}
134+
135+
await fulfillment(of: [taskShouldCancel])
136+
task.cancel()
137+
taskCancelled.fulfill()
138+
139+
switch transport.configuration.implementation {
140+
case .buffering:
141+
switch cancellationPoint {
142+
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
143+
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
144+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
145+
try await task.value
146+
}
147+
case .streaming:
148+
switch cancellationPoint {
149+
case .beforeSendingHead:
150+
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
151+
case .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
152+
await XCTAssertThrowsError(try await task.value) { error in
153+
guard let urlError = error as? URLError else {
154+
XCTFail()
155+
return
156+
}
157+
XCTAssertEqual(urlError.code, .cancelled)
158+
}
159+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
160+
try await task.value
161+
}
162+
}
163+
164+
group.cancelAll()
165+
}
166+
167+
}
168+
169+
func fulfillment(
170+
of expectations: [XCTestExpectation],
171+
timeout seconds: TimeInterval = .infinity,
172+
enforceOrder enforceOrderOfFulfillment: Bool = false,
173+
file: StaticString = #file,
174+
line: UInt = #line
175+
) async {
176+
guard
177+
case .completed = await XCTWaiter.fulfillment(
178+
of: expectations,
179+
timeout: seconds,
180+
enforceOrder: enforceOrderOfFulfillment
181+
)
182+
else {
183+
XCTFail("Expectation was not fulfilled", file: file, line: line)
184+
return
185+
}
186+
}
187+
188+
extension URLSessionTransportBufferedTests {
189+
func testCancellation_beforeSendingHead() async throws {
190+
try await testTaskCancelled(.beforeSendingHead, transport: transport)
191+
}
192+
193+
func testCancellation_beforeSendingRequestBody() async throws {
194+
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
195+
}
196+
197+
func testCancellation_partwayThroughSendingRequestBody() async throws {
198+
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
199+
}
200+
201+
func testCancellation_beforeConsumingResponseBody() async throws {
202+
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
203+
}
204+
205+
func testCancellation_partwayThroughConsumingResponseBody() async throws {
206+
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
207+
}
208+
209+
func testCancellation_afterConsumingResponseBody() async throws {
210+
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
211+
}
212+
}
213+
214+
extension URLSessionTransportStreamingTests {
215+
func testCancellation_beforeSendingHead() async throws {
216+
try await testTaskCancelled(.beforeSendingHead, transport: transport)
217+
}
218+
219+
func testCancellation_beforeSendingRequestBody() async throws {
220+
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
221+
}
222+
223+
func testCancellation_partwayThroughSendingRequestBody() async throws {
224+
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
225+
}
226+
227+
func testCancellation_beforeConsumingResponseBody() async throws {
228+
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
229+
}
230+
231+
func testCancellation_partwayThroughConsumingResponseBody() async throws {
232+
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
233+
}
234+
235+
func testCancellation_afterConsumingResponseBody() async throws {
236+
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
237+
}
238+
}
239+
240+
#endif // canImport(Darwin)

Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift

+4-4
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class URLSessionTransportConverterTests: XCTestCase {
5656

5757
// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
5858
class URLSessionTransportBufferedTests: XCTestCase {
59-
var transport: (any ClientTransport)!
59+
var transport: URLSessionTransport!
6060

6161
static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }
6262

@@ -66,7 +66,7 @@ class URLSessionTransportBufferedTests: XCTestCase {
6666

6767
func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }
6868

69-
func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
69+
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }
7070

7171
#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
7272
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {
@@ -89,7 +89,7 @@ class URLSessionTransportBufferedTests: XCTestCase {
8989

9090
// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
9191
class URLSessionTransportStreamingTests: XCTestCase {
92-
var transport: (any ClientTransport)!
92+
var transport: URLSessionTransport!
9393

9494
static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }
9595

@@ -107,7 +107,7 @@ class URLSessionTransportStreamingTests: XCTestCase {
107107

108108
func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }
109109

110-
func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
110+
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }
111111

112112
#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
113113
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {

0 commit comments

Comments
 (0)