|
| 1 | +//===----------------------------------------------------------------------===// |
| 2 | +// |
| 3 | +// This source file is part of the SwiftNIO open source project |
| 4 | +// |
| 5 | +// Copyright (c) 2024 Apple Inc. and the SwiftNIO project authors |
| 6 | +// Licensed under Apache License v2.0 |
| 7 | +// |
| 8 | +// See LICENSE.txt for license information |
| 9 | +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors |
| 10 | +// |
| 11 | +// SPDX-License-Identifier: Apache-2.0 |
| 12 | +// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
| 15 | +import HTTPTypes |
| 16 | +import NIOCore |
| 17 | +import NIOHTTPTypes |
| 18 | + |
| 19 | +/// HTTP request handler sending a configurable stream of zeroes. Uses HTTPTypes request/response parts. |
| 20 | +public final class HTTPDrippingDownloadHandler: ChannelDuplexHandler { |
| 21 | + public typealias InboundIn = HTTPRequestPart |
| 22 | + public typealias OutboundOut = HTTPResponsePart |
| 23 | + public typealias OutboundIn = Never |
| 24 | + |
| 25 | + // Predefine buffer to reuse over and over again when sending chunks to requester. NIO allows |
| 26 | + // us to give it reference counted buffers. Reusing like this allows us to avoid allocations. |
| 27 | + static let downloadBodyChunk = ByteBuffer(repeating: 0, count: 65536) |
| 28 | + |
| 29 | + private var frequency: TimeAmount |
| 30 | + private var size: Int |
| 31 | + private var count: Int |
| 32 | + private var delay: TimeAmount |
| 33 | + private var code: HTTPResponse.Status |
| 34 | + |
| 35 | + private enum Phase { |
| 36 | + /// We haven't gotten the request head - nothing to respond to |
| 37 | + case waitingOnHead |
| 38 | + /// We got the request head and are delaying the response |
| 39 | + case delayingResponse |
| 40 | + /// We're dripping response chunks to the peer, tracking how many chunks we have left |
| 41 | + case dripping(DrippingState) |
| 42 | + /// We either sent everything to the client or the request ended short |
| 43 | + case done |
| 44 | + } |
| 45 | + |
| 46 | + private struct DrippingState { |
| 47 | + var chunksLeft: Int |
| 48 | + var currentChunkBytesLeft: Int |
| 49 | + } |
| 50 | + |
| 51 | + private var phase = Phase.waitingOnHead |
| 52 | + private var scheduled: Scheduled<Void>? |
| 53 | + private var scheduledCallbackHandler: HTTPDrippingDownloadHandlerScheduledCallbackHandler? |
| 54 | + private var pendingRead = false |
| 55 | + private var pendingWrite = false |
| 56 | + private var activelyWritingChunk = false |
| 57 | + |
| 58 | + /// Initializes an `HTTPDrippingDownloadHandler`. |
| 59 | + /// - Parameters: |
| 60 | + /// - count: How many chunks should be sent. Note that the underlying HTTP |
| 61 | + /// stack may split or combine these chunks into data frames as |
| 62 | + /// they see fit. |
| 63 | + /// - size: How large each chunk should be |
| 64 | + /// - frequency: How much time to wait between sending each chunk |
| 65 | + /// - delay: How much time to wait before sending the first chunk |
| 66 | + /// - code: What HTTP status code to send |
| 67 | + public init( |
| 68 | + count: Int = 0, |
| 69 | + size: Int = 0, |
| 70 | + frequency: TimeAmount = .zero, |
| 71 | + delay: TimeAmount = .zero, |
| 72 | + code: HTTPResponse.Status = .ok |
| 73 | + ) { |
| 74 | + self.frequency = frequency |
| 75 | + self.size = size |
| 76 | + self.count = count |
| 77 | + self.delay = delay |
| 78 | + self.code = code |
| 79 | + } |
| 80 | + |
| 81 | + public convenience init?(queryArgsString: Substring.UTF8View) { |
| 82 | + self.init() |
| 83 | + |
| 84 | + let pairs = queryArgsString.split(separator: UInt8(ascii: "&")) |
| 85 | + for pair in pairs { |
| 86 | + var pairParts = pair.split(separator: UInt8(ascii: "="), maxSplits: 1).makeIterator() |
| 87 | + guard let first = pairParts.next(), let second = pairParts.next() else { |
| 88 | + continue |
| 89 | + } |
| 90 | + |
| 91 | + guard let secondNum = Int(Substring(second)) else { |
| 92 | + return nil |
| 93 | + } |
| 94 | + |
| 95 | + switch Substring(first) { |
| 96 | + case "frequency": |
| 97 | + self.frequency = .seconds(Int64(secondNum)) |
| 98 | + case "size": |
| 99 | + self.size = secondNum |
| 100 | + case "count": |
| 101 | + self.count = secondNum |
| 102 | + case "delay": |
| 103 | + self.delay = .seconds(Int64(secondNum)) |
| 104 | + case "code": |
| 105 | + self.code = HTTPResponse.Status(code: secondNum) |
| 106 | + default: |
| 107 | + continue |
| 108 | + } |
| 109 | + } |
| 110 | + } |
| 111 | + |
| 112 | + public func channelRead(context: ChannelHandlerContext, data: NIOAny) { |
| 113 | + let part = self.unwrapInboundIn(data) |
| 114 | + |
| 115 | + switch part { |
| 116 | + case .head: |
| 117 | + self.phase = .delayingResponse |
| 118 | + |
| 119 | + if self.delay == TimeAmount.zero { |
| 120 | + // If no delay, we might as well start responding immediately |
| 121 | + self.onResponseDelayCompleted(context: context) |
| 122 | + } else { |
| 123 | + let this = NIOLoopBound(self, eventLoop: context.eventLoop) |
| 124 | + let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop) |
| 125 | + self.scheduled = context.eventLoop.scheduleTask(in: self.delay) { |
| 126 | + this.value.onResponseDelayCompleted(context: loopBoundContext.value) |
| 127 | + } |
| 128 | + } |
| 129 | + case .body, .end: |
| 130 | + return |
| 131 | + } |
| 132 | + } |
| 133 | + |
| 134 | + private func onResponseDelayCompleted(context: ChannelHandlerContext) { |
| 135 | + guard case .delayingResponse = self.phase else { |
| 136 | + return |
| 137 | + } |
| 138 | + |
| 139 | + var head = HTTPResponse(status: self.code) |
| 140 | + |
| 141 | + // If the length isn't too big, let's include a content length header |
| 142 | + if case (let contentLength, false) = self.size.multipliedReportingOverflow(by: self.count) { |
| 143 | + head.headerFields = HTTPFields(dictionaryLiteral: (.contentLength, "\(contentLength)")) |
| 144 | + } |
| 145 | + |
| 146 | + context.writeAndFlush(self.wrapOutboundOut(.head(head)), promise: nil) |
| 147 | + self.phase = .dripping( |
| 148 | + DrippingState( |
| 149 | + chunksLeft: self.count, |
| 150 | + currentChunkBytesLeft: self.size |
| 151 | + ) |
| 152 | + ) |
| 153 | + |
| 154 | + self.writeChunk(context: context) |
| 155 | + } |
| 156 | + |
| 157 | + public func channelInactive(context: ChannelHandlerContext) { |
| 158 | + self.phase = .done |
| 159 | + self.scheduled?.cancel() |
| 160 | + context.fireChannelInactive() |
| 161 | + } |
| 162 | + |
| 163 | + public func channelWritabilityChanged(context: ChannelHandlerContext) { |
| 164 | + if case .dripping = self.phase, self.pendingWrite && context.channel.isWritable { |
| 165 | + self.writeChunk(context: context) |
| 166 | + } |
| 167 | + } |
| 168 | + |
| 169 | + private func writeChunk(context: ChannelHandlerContext) { |
| 170 | + // Make sure we don't accidentally reenter |
| 171 | + if self.activelyWritingChunk { |
| 172 | + return |
| 173 | + } |
| 174 | + self.activelyWritingChunk = true |
| 175 | + defer { |
| 176 | + self.activelyWritingChunk = false |
| 177 | + } |
| 178 | + |
| 179 | + // If we're not dripping the response body, there's nothing to do |
| 180 | + guard case .dripping(var drippingState) = self.phase else { |
| 181 | + return |
| 182 | + } |
| 183 | + |
| 184 | + // If we've sent all chunks, send end and be done |
| 185 | + if drippingState.chunksLeft < 1 { |
| 186 | + self.phase = .done |
| 187 | + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) |
| 188 | + return |
| 189 | + } |
| 190 | + |
| 191 | + var dataWritten = false |
| 192 | + while drippingState.currentChunkBytesLeft > 0, context.channel.isWritable { |
| 193 | + let toSend = min( |
| 194 | + drippingState.currentChunkBytesLeft, |
| 195 | + HTTPDrippingDownloadHandler.downloadBodyChunk.readableBytes |
| 196 | + ) |
| 197 | + let buffer = HTTPDrippingDownloadHandler.downloadBodyChunk.getSlice( |
| 198 | + at: HTTPDrippingDownloadHandler.downloadBodyChunk.readerIndex, |
| 199 | + length: toSend |
| 200 | + )! |
| 201 | + context.write(self.wrapOutboundOut(.body(buffer)), promise: nil) |
| 202 | + drippingState.currentChunkBytesLeft -= toSend |
| 203 | + dataWritten = true |
| 204 | + } |
| 205 | + |
| 206 | + // If we weren't able to send the full chunk, it's because the channel isn't writable. We yield until it is |
| 207 | + if drippingState.currentChunkBytesLeft > 0 { |
| 208 | + self.pendingWrite = true |
| 209 | + self.phase = .dripping(drippingState) |
| 210 | + if dataWritten { |
| 211 | + context.flush() |
| 212 | + } |
| 213 | + return |
| 214 | + } |
| 215 | + |
| 216 | + // We sent the full chunk. If we have no more chunks to write, we're done! |
| 217 | + drippingState.chunksLeft -= 1 |
| 218 | + if drippingState.chunksLeft == 0 { |
| 219 | + self.phase = .done |
| 220 | + context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) |
| 221 | + return |
| 222 | + } |
| 223 | + |
| 224 | + if dataWritten { |
| 225 | + context.flush() |
| 226 | + } |
| 227 | + |
| 228 | + // More chunks to write.. Kick off timer |
| 229 | + drippingState.currentChunkBytesLeft = self.size |
| 230 | + self.phase = .dripping(drippingState) |
| 231 | + if self.scheduledCallbackHandler == nil { |
| 232 | + let this = NIOLoopBound(self, eventLoop: context.eventLoop) |
| 233 | + let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop) |
| 234 | + self.scheduledCallbackHandler = HTTPDrippingDownloadHandlerScheduledCallbackHandler( |
| 235 | + handler: this, |
| 236 | + context: loopBoundContext |
| 237 | + ) |
| 238 | + } |
| 239 | + // SAFTEY: scheduling the callback only potentially throws when invoked off eventloop |
| 240 | + do { |
| 241 | + try context.eventLoop.scheduleCallback(in: self.frequency, handler: self.scheduledCallbackHandler!) |
| 242 | + } catch { |
| 243 | + context.fireErrorCaught(error) |
| 244 | + } |
| 245 | + } |
| 246 | + |
| 247 | + private struct HTTPDrippingDownloadHandlerScheduledCallbackHandler: NIOScheduledCallbackHandler & Sendable { |
| 248 | + var handler: NIOLoopBound<HTTPDrippingDownloadHandler> |
| 249 | + var context: NIOLoopBound<ChannelHandlerContext> |
| 250 | + |
| 251 | + func handleScheduledCallback(eventLoop: some EventLoop) { |
| 252 | + self.handler.value.writeChunk(context: self.context.value) |
| 253 | + } |
| 254 | + } |
| 255 | +} |
0 commit comments