Skip to content

Commit 95307ba

Browse files
LarsPetersHHLars Peters
and
Lars Peters
authored
Bug/502 crash thread safety fix (#95)
### Motivation Fixes apple/swift-openapi-generator#502 - Ensure thread safety of `HTTPBody.collect(upTo)`. - `makeAsyncIterator()`: Instead of crashing, return AsyncSequence which throws `TooManyIterationsError` thereby honoring the contract for `IterationBehavior.single` (HTTPBody, MultipartBody) ### Modifications - HTTPBody, MultipartBody: `makeAsyncIterator()`: removed `try!`, catch error and create a sequence which throws the error on iteration. - This removed the need for `try checkIfCanCreateIterator()` in `HTTPBody.collect(upTo)`. **Note**: This creates a small change in behavior: There may be a `TooManyBytesError` thrown before the check for `iterationBehavior`. This approach uses the simplest code, IMO. If we want to keep that `iterationBehavior` is checked first and only after that for the length, then the code needs to be more complex. - Removed `try checkIfCanCreateIterator()` in both classes (only used in `HTTPBody`). ### Result - No intentional crash in `makeAsyncIterator()` anymore. - Tests supplied as example in apple/swift-openapi-generator#502 succeed. ### Test Plan - Added check in `Test_Body.testIterationBehavior_single()` to ensure that using `makeAsyncIterator()` directly yields the expected error. - Added tests to check iteration behavior of `MultipartBody`. --------- Co-authored-by: Lars Peters <[email protected]>
1 parent 7f86e4a commit 95307ba

File tree

4 files changed

+71
-30
lines changed

4 files changed

+71
-30
lines changed

Sources/OpenAPIRuntime/Interface/HTTPBody.swift

+8-17
Original file line numberDiff line numberDiff line change
@@ -159,16 +159,6 @@ public final class HTTPBody: @unchecked Sendable {
159159
return locked_iteratorCreated
160160
}
161161

162-
/// Verifying that creating another iterator is allowed based on
163-
/// the values of `iterationBehavior` and `locked_iteratorCreated`.
164-
/// - Throws: If another iterator is not allowed to be created.
165-
private func checkIfCanCreateIterator() throws {
166-
lock.lock()
167-
defer { lock.unlock() }
168-
guard iterationBehavior == .single else { return }
169-
if locked_iteratorCreated { throw TooManyIterationsError() }
170-
}
171-
172162
/// Tries to mark an iterator as created, verifying that it is allowed
173163
/// based on the values of `iterationBehavior` and `locked_iteratorCreated`.
174164
/// - Throws: If another iterator is not allowed to be created.
@@ -341,10 +331,12 @@ extension HTTPBody: AsyncSequence {
341331
/// Creates and returns an asynchronous iterator
342332
///
343333
/// - Returns: An asynchronous iterator for byte chunks.
334+
/// - Note: The returned sequence throws an error if no further iterations are allowed. See ``IterationBehavior``.
344335
public func makeAsyncIterator() -> AsyncIterator {
345-
// The crash on error is intentional here.
346-
try! tryToMarkIteratorCreated()
347-
return .init(sequence.makeAsyncIterator())
336+
do {
337+
try tryToMarkIteratorCreated()
338+
return .init(sequence.makeAsyncIterator())
339+
} catch { return .init(throwing: error) }
348340
}
349341
}
350342

@@ -381,10 +373,6 @@ extension HTTPBody {
381373
/// than `maxBytes`.
382374
/// - Returns: A byte chunk containing all the accumulated bytes.
383375
fileprivate func collect(upTo maxBytes: Int) async throws -> ByteChunk {
384-
385-
// Check that we're allowed to iterate again.
386-
try checkIfCanCreateIterator()
387-
388376
// If the length is known, verify it's within the limit.
389377
if case .known(let knownBytes) = length {
390378
guard knownBytes <= maxBytes else { throw TooManyBytesError(maxBytes: maxBytes) }
@@ -563,6 +551,9 @@ extension HTTPBody {
563551
var iterator = iterator
564552
self.produceNext = { try await iterator.next() }
565553
}
554+
/// Creates an iterator throwing the given error when iterated.
555+
/// - Parameter error: The error to throw on iteration.
556+
fileprivate init(throwing error: any Error) { self.produceNext = { throw error } }
566557

567558
/// Advances the iterator to the next element and returns it asynchronously.
568559
///

Sources/OpenAPIRuntime/Multipart/MultipartPublicTypes.swift

+9-13
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,6 @@ public final class MultipartBody<Part: Sendable>: @unchecked Sendable {
209209
var errorDescription: String? { description }
210210
}
211211

212-
/// Verifying that creating another iterator is allowed based on the values of `iterationBehavior`
213-
/// and `locked_iteratorCreated`.
214-
/// - Throws: If another iterator is not allowed to be created.
215-
internal func checkIfCanCreateIterator() throws {
216-
lock.lock()
217-
defer { lock.unlock() }
218-
guard iterationBehavior == .single else { return }
219-
if locked_iteratorCreated { throw TooManyIterationsError() }
220-
}
221-
222212
/// Tries to mark an iterator as created, verifying that it is allowed based on the values
223213
/// of `iterationBehavior` and `locked_iteratorCreated`.
224214
/// - Throws: If another iterator is not allowed to be created.
@@ -331,10 +321,12 @@ extension MultipartBody: AsyncSequence {
331321
/// Creates and returns an asynchronous iterator
332322
///
333323
/// - Returns: An asynchronous iterator for parts.
324+
/// - Note: The returned sequence throws an error if no further iterations are allowed. See ``IterationBehavior``.
334325
public func makeAsyncIterator() -> AsyncIterator {
335-
// The crash on error is intentional here.
336-
try! tryToMarkIteratorCreated()
337-
return .init(sequence.makeAsyncIterator())
326+
do {
327+
try tryToMarkIteratorCreated()
328+
return .init(sequence.makeAsyncIterator())
329+
} catch { return .init(throwing: error) }
338330
}
339331
}
340332

@@ -355,6 +347,10 @@ extension MultipartBody {
355347
self.produceNext = { try await iterator.next() }
356348
}
357349

350+
/// Creates an iterator throwing the given error when iterated.
351+
/// - Parameter error: The error to throw on iteration.
352+
fileprivate init(throwing error: any Error) { self.produceNext = { throw error } }
353+
358354
/// Advances the iterator to the next element and returns it asynchronously.
359355
///
360356
/// - Returns: The next element in the sequence, or `nil` if there are no more elements.

Tests/OpenAPIRuntimeTests/Interface/Test_HTTPBody.swift

+5
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ final class Test_Body: Test_Runtime {
173173
_ = try await String(collecting: body, upTo: .max)
174174
XCTFail("Expected an error to be thrown")
175175
} catch {}
176+
177+
do {
178+
for try await _ in body {}
179+
XCTFail("Expected an error to be thrown")
180+
} catch {}
176181
}
177182

178183
func testIterationBehavior_multiple() async throws {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the SwiftOpenAPIGenerator open source project
4+
//
5+
// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
import XCTest
15+
@_spi(Generated) @testable import OpenAPIRuntime
16+
import Foundation
17+
18+
final class Test_MultipartBody: XCTestCase {
19+
20+
func testIterationBehavior_single() async throws {
21+
let sourceSequence = (0..<Int.random(in: 2..<10)).map { _ in UUID().uuidString }
22+
let body = MultipartBody(sourceSequence, iterationBehavior: .single)
23+
24+
XCTAssertFalse(body.testing_iteratorCreated)
25+
26+
let iterated = try await body.reduce("") { $0 + $1 }
27+
XCTAssertEqual(iterated, sourceSequence.joined())
28+
29+
XCTAssertTrue(body.testing_iteratorCreated)
30+
31+
do {
32+
for try await _ in body {}
33+
XCTFail("Expected an error to be thrown")
34+
} catch {}
35+
}
36+
37+
func testIterationBehavior_multiple() async throws {
38+
let sourceSequence = (0..<Int.random(in: 2..<10)).map { _ in UUID().uuidString }
39+
let body = MultipartBody(sourceSequence, iterationBehavior: .multiple)
40+
41+
XCTAssertFalse(body.testing_iteratorCreated)
42+
for _ in 0..<2 {
43+
let iterated = try await body.reduce("") { $0 + $1 }
44+
XCTAssertEqual(iterated, sourceSequence.joined())
45+
XCTAssertTrue(body.testing_iteratorCreated)
46+
}
47+
}
48+
49+
}

0 commit comments

Comments
 (0)