|
| 1 | +# python3 |
| 2 | + |
| 3 | +import collections |
| 4 | +import errno |
| 5 | +import heapq |
| 6 | +import json |
| 7 | +import random |
| 8 | +import selectors |
| 9 | +import socket as _socket |
| 10 | +import sys |
| 11 | +import time |
| 12 | + |
| 13 | + |
| 14 | +class EventLoop: |
| 15 | + def __init__(self): |
| 16 | + self._queue = Queue() |
| 17 | + self._time = None |
| 18 | + |
| 19 | + def run(self, entry_point, *args): |
| 20 | + self._execute(entry_point, *args) |
| 21 | + |
| 22 | + while not self._queue.is_empty(): |
| 23 | + fn, mask = self._queue.pop(self._time) |
| 24 | + self._execute(fn, mask) |
| 25 | + |
| 26 | + self._queue.close() |
| 27 | + |
| 28 | + def register_fileobj(self, fileobj, callback): |
| 29 | + self._queue.register_fileobj(fileobj, callback) |
| 30 | + |
| 31 | + def unregister_fileobj(self, fileobj): |
| 32 | + self._queue.unregister_fileobj(fileobj) |
| 33 | + |
| 34 | + def set_timer(self, duration, callback): |
| 35 | + self._time = hrtime() |
| 36 | + self._queue.register_timer(self._time + duration, |
| 37 | + lambda _: callback()) |
| 38 | + |
| 39 | + def _execute(self, callback, *args): |
| 40 | + self._time = hrtime() |
| 41 | + try: |
| 42 | + callback(*args) # new callstack starts |
| 43 | + except Exception as err: |
| 44 | + print('Uncaught exception:', err) |
| 45 | + self._time = hrtime() |
| 46 | + |
| 47 | + |
| 48 | +class Queue: |
| 49 | + def __init__(self): |
| 50 | + self._selector = selectors.DefaultSelector() |
| 51 | + self._timers = [] |
| 52 | + self._timer_no = 0 |
| 53 | + self._ready = collections.deque() |
| 54 | + |
| 55 | + def register_timer(self, tick, callback): |
| 56 | + timer = (tick, self._timer_no, callback) |
| 57 | + heapq.heappush(self._timers, timer) |
| 58 | + self._timer_no += 1 |
| 59 | + |
| 60 | + def register_fileobj(self, fileobj, callback): |
| 61 | + self._selector.register(fileobj, |
| 62 | + selectors.EVENT_READ | selectors.EVENT_WRITE, |
| 63 | + callback) |
| 64 | + |
| 65 | + def unregister_fileobj(self, fileobj): |
| 66 | + self._selector.unregister(fileobj) |
| 67 | + |
| 68 | + def pop(self, tick): |
| 69 | + if self._ready: |
| 70 | + return self._ready.popleft() |
| 71 | + |
| 72 | + timeout = None |
| 73 | + if self._timers: |
| 74 | + timeout = (self._timers[0][0] - tick) / 10e6 |
| 75 | + |
| 76 | + events = self._selector.select(timeout) |
| 77 | + for key, mask in events: |
| 78 | + callback = key.data |
| 79 | + self._ready.append((callback, mask)) |
| 80 | + |
| 81 | + if not self._ready and self._timers: |
| 82 | + idle = (self._timers[0][0] - tick) |
| 83 | + if idle > 0: |
| 84 | + time.sleep(idle / 10e6) |
| 85 | + return self.pop(tick + idle) |
| 86 | + |
| 87 | + while self._timers and self._timers[0][0] <= tick: |
| 88 | + _, _, callback = heapq.heappop(self._timers) |
| 89 | + self._ready.append((callback, None)) |
| 90 | + |
| 91 | + return self._ready.popleft() |
| 92 | + |
| 93 | + def is_empty(self): |
| 94 | + return not (self._ready or self._timers or self._selector.get_map()) |
| 95 | + |
| 96 | + def close(self): |
| 97 | + self._selector.close() |
| 98 | + |
| 99 | + |
| 100 | +class Context: |
| 101 | + _event_loop = None |
| 102 | + |
| 103 | + @classmethod |
| 104 | + def set_event_loop(cls, event_loop): |
| 105 | + cls._event_loop = event_loop |
| 106 | + |
| 107 | + @property |
| 108 | + def evloop(self): |
| 109 | + return self._event_loop |
| 110 | + |
| 111 | + |
| 112 | +class IOError(Exception): |
| 113 | + def __init__(self, message, errorno, errorcode): |
| 114 | + super().__init__(message) |
| 115 | + self.errorno = errorno |
| 116 | + self.errorcode = errorcode |
| 117 | + |
| 118 | + def __str__(self): |
| 119 | + return super().__str__() + f' (error {self.errorno} {self.errorcode})' |
| 120 | + |
| 121 | + |
| 122 | +def hrtime(): |
| 123 | + """ returns time in microseconds """ |
| 124 | + return int(time.time() * 10e6) |
| 125 | + |
| 126 | + |
| 127 | +class set_timer(Context): |
| 128 | + def __init__(self, duration, callback): |
| 129 | + """ duration is in microseconds """ |
| 130 | + self.evloop.set_timer(duration, callback) |
| 131 | + |
| 132 | + |
| 133 | +class socket(Context): |
| 134 | + def __init__(self, *args): |
| 135 | + self._sock = _socket.socket(*args) |
| 136 | + self._sock.setblocking(False) |
| 137 | + self.evloop.register_fileobj(self._sock, self._on_event) |
| 138 | + # 0 - initial |
| 139 | + # 1 - connecting |
| 140 | + # 2 - connected |
| 141 | + # 3 - closed |
| 142 | + self._state = 0 |
| 143 | + self._callbacks = {} |
| 144 | + |
| 145 | + def connect(self, addr, callback): |
| 146 | + assert self._state == 0 |
| 147 | + self._state = 1 |
| 148 | + self._callbacks['conn'] = callback |
| 149 | + err = self._sock.connect_ex(addr) |
| 150 | + assert errno.errorcode[err] == 'EINPROGRESS' |
| 151 | + |
| 152 | + def recv(self, n, callback): |
| 153 | + assert self._state == 2 |
| 154 | + assert 'recv' not in self._callbacks |
| 155 | + |
| 156 | + def _on_read_ready(err): |
| 157 | + if err: |
| 158 | + return callback(err) |
| 159 | + data = self._sock.recv(n) |
| 160 | + callback(None, data) |
| 161 | + |
| 162 | + self._callbacks['recv'] = _on_read_ready |
| 163 | + |
| 164 | + def sendall(self, data, callback): |
| 165 | + assert self._state == 2 |
| 166 | + assert 'sent' not in self._callbacks |
| 167 | + |
| 168 | + def _on_write_ready(err): |
| 169 | + nonlocal data |
| 170 | + if err: |
| 171 | + return callback(err) |
| 172 | + |
| 173 | + n = self._sock.send(data) |
| 174 | + if n < len(data): |
| 175 | + data = data[n:] |
| 176 | + self._callbacks['sent'] = _on_write_ready |
| 177 | + else: |
| 178 | + callback(None) |
| 179 | + |
| 180 | + self._callbacks['sent'] = _on_write_ready |
| 181 | + |
| 182 | + def close(self): |
| 183 | + self.evloop.unregister_fileobj(self._sock) |
| 184 | + self._callbacks.clear() |
| 185 | + self._state = 3 |
| 186 | + self._sock.close() |
| 187 | + |
| 188 | + def _on_event(self, mask): |
| 189 | + if self._state == 1: |
| 190 | + assert mask == selectors.EVENT_WRITE |
| 191 | + cb = self._callbacks.pop('conn') |
| 192 | + err = self._get_sock_error() |
| 193 | + if err: |
| 194 | + self.close() |
| 195 | + else: |
| 196 | + self._state = 2 |
| 197 | + cb(err) |
| 198 | + |
| 199 | + if mask & selectors.EVENT_READ: |
| 200 | + cb = self._callbacks.get('recv') |
| 201 | + if cb: |
| 202 | + del self._callbacks['recv'] |
| 203 | + err = self._get_sock_error() |
| 204 | + cb(err) |
| 205 | + |
| 206 | + if mask & selectors.EVENT_WRITE: |
| 207 | + cb = self._callbacks.get('sent') |
| 208 | + if cb: |
| 209 | + del self._callbacks['sent'] |
| 210 | + err = self._get_sock_error() |
| 211 | + cb(err) |
| 212 | + |
| 213 | + def _get_sock_error(self): |
| 214 | + err = self._sock.getsockopt(_socket.SOL_SOCKET, |
| 215 | + _socket.SO_ERROR) |
| 216 | + if not err: |
| 217 | + return None |
| 218 | + return IOError('connection failed', |
| 219 | + err, errno.errorcode[err]) |
| 220 | + |
| 221 | +############################################################################### |
| 222 | + |
| 223 | +class Client: |
| 224 | + def __init__(self, addr): |
| 225 | + self.addr = addr |
| 226 | + |
| 227 | + def get_user(self, user_id, callback): |
| 228 | + self._get(f'GET user {user_id}\n', callback) |
| 229 | + |
| 230 | + def get_balance(self, account_id, callback): |
| 231 | + self._get(f'GET account {account_id}\n', callback) |
| 232 | + |
| 233 | + def _get(self, req, callback): |
| 234 | + sock = socket(_socket.AF_INET, _socket.SOCK_STREAM) |
| 235 | + |
| 236 | + def _on_conn(err): |
| 237 | + if err: |
| 238 | + return callback(err) |
| 239 | + |
| 240 | + def _on_sent(err): |
| 241 | + if err: |
| 242 | + sock.close() |
| 243 | + return callback(err) |
| 244 | + |
| 245 | + def _on_resp(err, resp=None): |
| 246 | + sock.close() |
| 247 | + if err: |
| 248 | + return callback(err) |
| 249 | + callback(None, json.loads(resp)) |
| 250 | + |
| 251 | + sock.recv(1024, _on_resp) |
| 252 | + |
| 253 | + sock.sendall(req.encode('utf8'), _on_sent) |
| 254 | + |
| 255 | + sock.connect(self.addr, _on_conn) |
| 256 | + |
| 257 | + |
| 258 | +def get_user_balance(serv_addr, user_id, done): |
| 259 | + client = Client(serv_addr) |
| 260 | + |
| 261 | + def on_timer(): |
| 262 | + |
| 263 | + def on_user(err, user=None): |
| 264 | + if err: |
| 265 | + return done(err) |
| 266 | + |
| 267 | + def on_account(err, acc=None): |
| 268 | + if err: |
| 269 | + return done(err) |
| 270 | + done(None, f'User {user["name"]} has {acc["balance"]} USD') |
| 271 | + |
| 272 | + if user_id % 5 == 0: |
| 273 | + raise Exception('Do not throw from callbacks') |
| 274 | + client.get_balance(user['account_id'], on_account) |
| 275 | + |
| 276 | + client.get_user(user_id, on_user) |
| 277 | + |
| 278 | + set_timer(random.randint(0, 10e6), on_timer) |
| 279 | + |
| 280 | + |
| 281 | +def main(serv_addr): |
| 282 | + def on_balance(err, balance=None): |
| 283 | + if err: |
| 284 | + print('ERROR', err) |
| 285 | + else: |
| 286 | + print(balance) |
| 287 | + |
| 288 | + for i in range(10): |
| 289 | + get_user_balance(serv_addr, i, on_balance) |
| 290 | + |
| 291 | + |
| 292 | +if __name__ == '__main__': |
| 293 | + event_loop = EventLoop() |
| 294 | + Context.set_event_loop(event_loop) |
| 295 | + |
| 296 | + serv_addr = ('127.0.0.1', int(sys.argv[1])) |
| 297 | + event_loop.run(main, serv_addr) |
| 298 | + |
0 commit comments