@@ -40,6 +40,9 @@ public actor HTTPClient {
40
40
/// Array of `HostErrors` values, which is used for applying a circuit-breaking strategy.
41
41
private var hostsErrors = [ String: HostErrors] ( )
42
42
43
+ /// Tracks all active network request tasks.
44
+ private var activeTasks : Set < Task < HTTPClient . Response , Error > > = [ ]
45
+
43
46
public init ( configuration: HTTPClientConfiguration = . init( ) , implementation: Implementation ? = nil ) {
44
47
self . configuration = configuration
45
48
self . implementation = implementation ?? URLSessionHTTPClient ( ) . execute
@@ -92,6 +95,21 @@ public actor HTTPClient {
92
95
return try await self . executeWithStrategies ( request: request, requestNumber: 0 , observabilityScope, progress)
93
96
}
94
97
98
+ /// Cancel all in flight network reqeusts.
99
+ public func cancel( deadline: DispatchTime ) async {
100
+ for task in activeTasks {
101
+ task. cancel ( )
102
+ }
103
+
104
+ // Wait for tasks to complete or timeout
105
+ while !activeTasks. isEmpty && ( deadline. distance ( to: . now( ) ) . nanoseconds ( ) ?? 0 ) > 0 {
106
+ await Task . yield ( )
107
+ }
108
+
109
+ // Clear out the active task list regardless of whether they completed or not
110
+ activeTasks. removeAll ( )
111
+ }
112
+
95
113
private func executeWithStrategies(
96
114
request: Request ,
97
115
requestNumber: Int ,
@@ -104,46 +122,57 @@ public actor HTTPClient {
104
122
throw HTTPClientError . circuitBreakerTriggered
105
123
}
106
124
107
- let response = try await self . tokenBucket. withToken {
108
- try await self . implementation ( request) { received, expected in
109
- if let max = request. options. maximumResponseSizeInBytes {
110
- guard received < max else {
111
- // It's a responsibility of the underlying client implementation to cancel the request
112
- // when this closure throws an error
113
- throw HTTPClientError . responseTooLarge ( received)
125
+ let task = Task {
126
+ let response = try await self . tokenBucket. withToken {
127
+ try Task . checkCancellation ( )
128
+
129
+ return try await self . implementation ( request) { received, expected in
130
+ if let max = request. options. maximumResponseSizeInBytes {
131
+ guard received < max else {
132
+ // It's a responsibility of the underlying client implementation to cancel the request
133
+ // when this closure throws an error
134
+ throw HTTPClientError . responseTooLarge ( received)
135
+ }
114
136
}
115
- }
116
137
117
- try progress ? ( received, expected)
138
+ try progress ? ( received, expected)
139
+ }
118
140
}
119
- }
120
141
121
- self . recordErrorIfNecessary ( response: response, request: request)
142
+ self . recordErrorIfNecessary ( response: response, request: request)
122
143
123
- // handle retry strategy
124
- if let retryDelay = self . calculateRetry (
125
- response: response,
126
- request: request,
127
- requestNumber: requestNumber
128
- ) , let retryDelayInNanoseconds = retryDelay. nanoseconds ( ) {
129
- observabilityScope? . emit ( warning: " \( request. url) failed, retrying in \( retryDelay) " )
130
- try await Task . sleep ( nanoseconds: UInt64 ( retryDelayInNanoseconds) )
131
-
132
- return try await self . executeWithStrategies (
144
+ // handle retry strategy
145
+ if let retryDelay = self . calculateRetry (
146
+ response: response,
133
147
request: request,
134
- requestNumber: requestNumber + 1 ,
135
- observabilityScope,
136
- progress
137
- )
138
- }
139
- // check for valid response codes
140
- if let validResponseCodes = request. options. validResponseCodes,
141
- !validResponseCodes. contains ( response. statusCode)
142
- {
143
- throw HTTPClientError . badResponseStatusCode ( response. statusCode)
144
- } else {
145
- return response
148
+ requestNumber: requestNumber
149
+ ) , let retryDelayInNanoseconds = retryDelay. nanoseconds ( ) {
150
+ try Task . checkCancellation ( )
151
+
152
+ observabilityScope? . emit ( warning: " \( request. url) failed, retrying in \( retryDelay) " )
153
+ try await Task . sleep ( nanoseconds: UInt64 ( retryDelayInNanoseconds) )
154
+
155
+ return try await self . executeWithStrategies (
156
+ request: request,
157
+ requestNumber: requestNumber + 1 ,
158
+ observabilityScope,
159
+ progress
160
+ )
161
+ }
162
+ // check for valid response codes
163
+ if let validResponseCodes = request. options. validResponseCodes,
164
+ !validResponseCodes. contains ( response. statusCode)
165
+ {
166
+ throw HTTPClientError . badResponseStatusCode ( response. statusCode)
167
+ } else {
168
+ return response
169
+ }
146
170
}
171
+
172
+ activeTasks. insert ( task)
173
+ defer { activeTasks. remove ( task) }
174
+
175
+ return try await task. value
147
176
}
148
177
149
178
private func calculateRetry( response: Response , request: Request , requestNumber: Int ) -> SendableTimeInterval ? {
@@ -258,4 +287,26 @@ extension HTTPClient {
258
287
Request ( method: . delete, url: url, headers: headers, body: nil , options: options)
259
288
)
260
289
}
290
+
291
+ public func download(
292
+ _ url: URL ,
293
+ headers: HTTPClientHeaders = . init( ) ,
294
+ options: Request . Options = . init( ) ,
295
+ progressHandler: ProgressHandler ? = nil ,
296
+ fileSystem: FileSystem ,
297
+ destination: AbsolutePath ,
298
+ observabilityScope: ObservabilityScope ? = . none
299
+ ) async throws -> Response {
300
+ try await self . execute (
301
+ Request (
302
+ kind: . download( fileSystem: fileSystem, destination: destination) ,
303
+ url: url,
304
+ headers: headers,
305
+ body: nil ,
306
+ options: options
307
+ ) ,
308
+ observabilityScope: observabilityScope,
309
+ progress: progressHandler
310
+ )
311
+ }
261
312
}
0 commit comments