|
| 1 | +import Foundation |
| 2 | + |
| 3 | +public enum Errors: Error { |
| 4 | + case connectionError(String) |
| 5 | +} |
| 6 | + |
| 7 | +public enum ReadyState: Int { |
| 8 | + case connecting = 0 |
| 9 | + case open = 1 |
| 10 | + case closing = 2 |
| 11 | + case closed = 3 |
| 12 | +} |
| 13 | + |
| 14 | +// sora gay |
| 15 | +// let's keep this lmao |
| 16 | +// YEP |
| 17 | + |
| 18 | +/* API Usage |
| 19 | + var socket = WebSocketStream("wss://gateway.discord.gg/?v=9&encoding=json") |
| 20 | + |
| 21 | + socket.closed { code, reason in |
| 22 | + ... |
| 23 | + } |
| 24 | + |
| 25 | + socket.error { error in |
| 26 | + ... |
| 27 | + } |
| 28 | + |
| 29 | + try! await socket.ready() |
| 30 | + |
| 31 | + for await msg in socket { |
| 32 | + ... |
| 33 | + } |
| 34 | + */ |
| 35 | + |
| 36 | +private class SessionDelegate: NSObject, URLSessionWebSocketDelegate { |
| 37 | + private let didOpen: (String?) -> Void |
| 38 | + private let didClose: (Int, Data?) -> Void |
| 39 | + |
| 40 | + init(didOpen: @escaping (String?) -> Void, didClose: @escaping (Int, Data?) -> Void) { |
| 41 | + self.didOpen = didOpen |
| 42 | + self.didClose = didClose |
| 43 | + } |
| 44 | + |
| 45 | + func urlSession( |
| 46 | + _: URLSession, |
| 47 | + webSocketTask _: URLSessionWebSocketTask, |
| 48 | + didOpenWithProtocol: String? |
| 49 | + ) { |
| 50 | + didOpen(didOpenWithProtocol) |
| 51 | + } |
| 52 | + |
| 53 | + func urlSession( |
| 54 | + _: URLSession, |
| 55 | + webSocketTask _: URLSessionWebSocketTask, |
| 56 | + didCloseWith: URLSessionWebSocketTask.CloseCode, |
| 57 | + reason: Data? |
| 58 | + ) { |
| 59 | + didClose(didCloseWith.rawValue, reason) |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +public typealias ClosedHandler = (Int, String?) -> Void |
| 64 | +public typealias ErrorHandler = (Error) -> Void |
| 65 | + |
| 66 | +public protocol WSMessage {} |
| 67 | +extension String: WSMessage {} |
| 68 | +extension Data: WSMessage {} |
| 69 | + |
| 70 | +public class WebSocketStream: AsyncSequence { |
| 71 | + private var messageHandlers: [(WSMessage) -> Void] = [] |
| 72 | + |
| 73 | + public class AsyncIterator: AsyncIteratorProtocol { |
| 74 | + var websocketStream: WebSocketStream |
| 75 | + private var _nextHandler: ((WSMessage) -> Void)? |
| 76 | + private var messageCache: [WSMessage] = [] |
| 77 | + var nextHandler: ((WSMessage) -> Void)? { |
| 78 | + get { |
| 79 | + self._nextHandler |
| 80 | + } |
| 81 | + set(newHandler) { |
| 82 | + self._nextHandler = newHandler |
| 83 | + if let newHandler = newHandler { |
| 84 | + for msg in messageCache { |
| 85 | + newHandler(msg) |
| 86 | + } |
| 87 | + } |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + public func next() async -> WSMessage? { |
| 92 | + let msg: WSMessage? = await withCheckedContinuation { ctx in |
| 93 | + if self.websocketStream.readyState == .closed { |
| 94 | + ctx.resume(returning: nil) |
| 95 | + } else { |
| 96 | + nextHandler = { msg in |
| 97 | + ctx.resume(returning: msg) |
| 98 | + } |
| 99 | + } |
| 100 | + } |
| 101 | + return msg |
| 102 | + } |
| 103 | + |
| 104 | + public init(websocketStream: WebSocketStream) { |
| 105 | + self.websocketStream = websocketStream |
| 106 | + self.websocketStream.message(handler) |
| 107 | + } |
| 108 | + |
| 109 | + private func handler(msg: WSMessage) { |
| 110 | + if let nextHandler = nextHandler { |
| 111 | + nextHandler(msg) |
| 112 | + } else { |
| 113 | + messageCache.append(msg) |
| 114 | + } |
| 115 | + } |
| 116 | + } |
| 117 | + |
| 118 | + public func makeAsyncIterator() -> AsyncIterator { |
| 119 | + return AsyncIterator(websocketStream: self) |
| 120 | + } |
| 121 | + |
| 122 | + public typealias Element = WSMessage |
| 123 | + |
| 124 | + public let url: URL |
| 125 | + private(set) var readyState = ReadyState.connecting |
| 126 | + private(set) var subProtocol: String? |
| 127 | + private var wsTask: URLSessionWebSocketTask? |
| 128 | + public var maximumMessageSize: Int { |
| 129 | + get { |
| 130 | + return wsTask?.maximumMessageSize ?? 0 |
| 131 | + } |
| 132 | + set(value) { |
| 133 | + if let wsTask = self.wsTask { |
| 134 | + wsTask.maximumMessageSize = value |
| 135 | + } |
| 136 | + } |
| 137 | + } |
| 138 | + |
| 139 | + private var readyHandles: [() -> Void] = [] |
| 140 | + private var closedHandlers: [ClosedHandler] = [] |
| 141 | + private var errorHandlers: [ErrorHandler] = [] |
| 142 | + |
| 143 | + private var messageCache: [WSMessage] = [] |
| 144 | + |
| 145 | + /// Wait for the connection to get ready. |
| 146 | + public func ready() async { |
| 147 | + handleMessage() |
| 148 | + wsTask?.resume() |
| 149 | + let _: () = await withCheckedContinuation { ctx in |
| 150 | + if self.readyState == .open { |
| 151 | + ctx.resume() |
| 152 | + } else { |
| 153 | + readyHandles.append { |
| 154 | + ctx.resume() |
| 155 | + } |
| 156 | + } |
| 157 | + } |
| 158 | + } |
| 159 | + |
| 160 | + init(url: URL, protocols: [String] = [], headers: [String: String] = [:]) { |
| 161 | + self.url = url |
| 162 | + let session = URLSession( |
| 163 | + configuration: URLSessionConfiguration.default, |
| 164 | + delegate: SessionDelegate( |
| 165 | + didOpen: { [weak self] proto in |
| 166 | + guard let self = self else { return } |
| 167 | + self.subProtocol = proto |
| 168 | + self.readyState = .open |
| 169 | + |
| 170 | + for handle in self.readyHandles { |
| 171 | + handle() |
| 172 | + } |
| 173 | + }, |
| 174 | + didClose: { [weak self] closedCode, reason in |
| 175 | + guard let self = self else { return } |
| 176 | + self.readyState = .closed |
| 177 | + for handler in self.closedHandlers { |
| 178 | + handler(closedCode, String(data: reason ?? Data(), encoding: .utf8)) |
| 179 | + } |
| 180 | + } |
| 181 | + ), |
| 182 | + delegateQueue: nil |
| 183 | + ) |
| 184 | + var request = URLRequest(url: url) |
| 185 | + request.allHTTPHeaderFields = [:] |
| 186 | + if protocols.count > 0 { |
| 187 | + request.allHTTPHeaderFields!["Sec-WebSocket-Protocol"] = protocols.joined(separator: ",") |
| 188 | + } |
| 189 | + for header in headers { |
| 190 | + request.allHTTPHeaderFields![header.key] = header.value |
| 191 | + } |
| 192 | + |
| 193 | + wsTask = session.webSocketTask(with: request) |
| 194 | + } |
| 195 | + |
| 196 | + /// Send binary data to the Web Socket |
| 197 | + public func send(_ data: Data) async throws { |
| 198 | + try await wsTask?.send(.data(data)) |
| 199 | + } |
| 200 | + |
| 201 | + /// Send text data to the Web Socket |
| 202 | + public func send(_ string: String) async throws { |
| 203 | + try await wsTask?.send(.string(string)) |
| 204 | + } |
| 205 | + |
| 206 | + public typealias CloseCode = URLSessionWebSocketTask.CloseCode |
| 207 | + |
| 208 | + public func close(code: CloseCode, reason: String) { |
| 209 | + close(code: code, reason: reason.data(using: .utf8)!) |
| 210 | + } |
| 211 | + |
| 212 | + public func close(code: CloseCode, reason: Data) { |
| 213 | + wsTask?.cancel(with: code, reason: reason) |
| 214 | + } |
| 215 | + |
| 216 | + public func close(code: CloseCode) { |
| 217 | + wsTask?.cancel(with: code, reason: nil) |
| 218 | + } |
| 219 | + |
| 220 | + private func handleMessage() { |
| 221 | + wsTask?.receive { [weak self] result in |
| 222 | + guard let self = self else { return } |
| 223 | + |
| 224 | + switch result { // fak can't find a good swift ws server library |
| 225 | + case let .failure(error): |
| 226 | + for handler in self.errorHandlers { |
| 227 | + handler(error) |
| 228 | + } |
| 229 | + case let .success(msg): |
| 230 | + // oh fuck |
| 231 | + // should we use a message queue then hmm |
| 232 | + // or handleMessage() after some async iter is made |
| 233 | + switch msg { |
| 234 | + case let .data(data): |
| 235 | + if self.messageHandlers.count == 0 { |
| 236 | + self.messageCache.append(data) |
| 237 | + } else { |
| 238 | + for handler in self.messageHandlers { |
| 239 | + handler(data) |
| 240 | + } |
| 241 | + } |
| 242 | + case let .string(str): |
| 243 | + if self.messageHandlers.count == 0 { |
| 244 | + self.messageCache.append(str) |
| 245 | + } else { |
| 246 | + for handler in self.messageHandlers { |
| 247 | + handler(str) |
| 248 | + } |
| 249 | + } |
| 250 | + @unknown default: break |
| 251 | + } |
| 252 | + } |
| 253 | + |
| 254 | + self.handleMessage() |
| 255 | + } |
| 256 | + } |
| 257 | + |
| 258 | + func closed(_ handler: @escaping ClosedHandler) { |
| 259 | + closedHandlers.append(handler) |
| 260 | + } |
| 261 | + |
| 262 | + func error(_ handler: @escaping ErrorHandler) { |
| 263 | + errorHandlers.append(handler) |
| 264 | + } |
| 265 | + |
| 266 | + func message(_ handler: @escaping (WSMessage) -> Void) { |
| 267 | + if messageHandlers.count == 0 { |
| 268 | + for msg in messageCache { |
| 269 | + handler(msg) |
| 270 | + } |
| 271 | + } |
| 272 | + messageHandlers.append(handler) |
| 273 | + } |
| 274 | +} |
0 commit comments