diff --git a/Sources/GRPC/ConnectionBackoff.swift b/Sources/GRPC/ConnectionBackoff.swift index 561fc23c6..f6f4376c1 100644 --- a/Sources/GRPC/ConnectionBackoff.swift +++ b/Sources/GRPC/ConnectionBackoff.swift @@ -99,7 +99,7 @@ public struct ConnectionBackoff: Sequence, Sendable { /// An iterator for ``ConnectionBackoff``. public class ConnectionBackoffIterator: IteratorProtocol { - public typealias Element = (timeout: TimeInterval, backoff: TimeInterval) + public typealias Element = (timeout: TimeInterval, backoff: TimeInterval?) /// Creates a new connection backoff iterator with the given configuration. public init(connectionBackoff: ConnectionBackoff) { @@ -135,7 +135,12 @@ public class ConnectionBackoffIterator: IteratorProtocol { case let .limited(limit) where limit > 0: self.connectionBackoff.retries.limit = .limited(limit - 1) - // limit must be <= 0, no new element. + // limit is reached, return an element with only a timeout. + case let .limited(limit) where limit == 0: + self.connectionBackoff.retries.limit = .limited(limit - 1) + return self.makeElement(backoff: nil) + + // limit must be < 0, no new element. case .limited: return nil } @@ -159,8 +164,8 @@ public class ConnectionBackoffIterator: IteratorProtocol { /// Make a timeout-backoff pair from the given backoff. The timeout is the `max` of the backoff /// and `connectionBackoff.minimumConnectionTimeout`. - private func makeElement(backoff: TimeInterval) -> Element { - let timeout = max(backoff, self.connectionBackoff.minimumConnectionTimeout) + private func makeElement(backoff: TimeInterval?) -> Element { + let timeout = max(backoff ?? self.unjitteredBackoff, self.connectionBackoff.minimumConnectionTimeout) return (timeout, backoff) } diff --git a/Sources/GRPC/ConnectionManager.swift b/Sources/GRPC/ConnectionManager.swift index 35ee41101..f415224a8 100644 --- a/Sources/GRPC/ConnectionManager.swift +++ b/Sources/GRPC/ConnectionManager.swift @@ -997,7 +997,10 @@ extension ConnectionManager { } // Should we reconnect if the candidate channel fails? - let reconnect: Reconnect = timeoutAndBackoff.map { .after($0.backoff) } ?? .none + var reconnect = Reconnect.none + if let backoff = timeoutAndBackoff?.backoff { + reconnect = Reconnect.after(backoff) + } let connecting = ConnectingState( backoffIterator: backoffIterator, reconnect: reconnect, diff --git a/Tests/GRPCTests/ConnectionBackoffTests.swift b/Tests/GRPCTests/ConnectionBackoffTests.swift index 6308e8afd..23bf0dca2 100644 --- a/Tests/GRPCTests/ConnectionBackoffTests.swift +++ b/Tests/GRPCTests/ConnectionBackoffTests.swift @@ -43,7 +43,7 @@ class ConnectionBackoffTests: GRPCTestCase { pow(self.backoff.initialBackoff * self.backoff.multiplier, Double(i)), self.backoff.maximumBackoff ) - XCTAssertEqual(expected, backoff, accuracy: 1e-6) + XCTAssertEqual(expected, backoff!, accuracy: 1e-6) } } @@ -55,7 +55,7 @@ class ConnectionBackoffTests: GRPCTestCase { ) let halfJitterRange = self.backoff.jitter * unjittered let jitteredRange = (unjittered - halfJitterRange) ... (unjittered + halfJitterRange) - XCTAssert(jitteredRange.contains(timeoutAndBackoff.backoff)) + XCTAssert(jitteredRange.contains(timeoutAndBackoff.backoff!)) } } @@ -65,7 +65,7 @@ class ConnectionBackoffTests: GRPCTestCase { self.backoff.jitter = 0.0 for backoff in self.backoff.prefix(100).map({ $0.backoff }) { - XCTAssertLessThanOrEqual(backoff, self.backoff.maximumBackoff) + XCTAssertLessThanOrEqual(backoff!, self.backoff.maximumBackoff) } } @@ -79,19 +79,19 @@ class ConnectionBackoffTests: GRPCTestCase { for limit in [1, 3, 5] { let backoff = ConnectionBackoff(retries: .upTo(limit)) let values = Array(backoff) - XCTAssertEqual(values.count, limit) + XCTAssertEqual(values.count, limit+1) } } func testConnectionBackoffWhenLimitedToZeroRetries() { let backoff = ConnectionBackoff(retries: .upTo(0)) let values = Array(backoff) - XCTAssertTrue(values.isEmpty) + XCTAssertEqual(values.count, 1) } func testConnectionBackoffWithNoRetries() { let backoff = ConnectionBackoff(retries: .none) let values = Array(backoff) - XCTAssertTrue(values.isEmpty) + XCTAssertEqual(values.count, 1) } }