-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathanodb.py
executable file
·443 lines (396 loc) · 17.1 KB
/
anodb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
#
# This marvelous code is Public Domain.
#
import re
from typing import Any, Callable
import logging
import importlib
import functools as ft
import datetime as dt
import time
from collections import deque
import aiosql as sql  # type: ignore
from aiosql.types import DriverAdapterProtocol, SQLOperationType as Ops
import json

# module-level logger, shared by all DB instances
log = logging.getLogger("anodb")

# get package version
from importlib.metadata import version as pkg_version

__version__ = pkg_version("anodb")

# signature of a user-provided cache factory: (query name, function) -> wrapped function
CacheFactory = Callable[[str, Callable], Callable]
"""Type for caching cachable queries."""
class AnoDBException(Exception):
    """Exception raised by anodb itself, as opposed to driver errors."""
#
# DB (Database) class
#
class DB:
    """
    Class to hide database connection and queries.

    The class provides the DB-API 2.0 connection methods,
    and wraps SQL execution methods from aiosql.

    Constructor:

    - :param db: database engine/driver.
    - :param conn: database-specific simple connection string.
    - :param queries: file(s) holding queries for `aiosql`, may be empty.
    - :param conn_args: database-specific connection options as a list.
    - :param conn_kwargs: database-specific connection options as a dict.
    - :param adapter_args: adapter creation options as a list.
    - :param adapter_kwargs: adapter creation options as a dict.
    - :param auto_reconnect: whether to reconnect on connection errors, default is true.
    - :param auto_rollback: whether to rollback internally on errors, default is true.
    - :param kwargs_only: whether to require named parameters on query execution, default is *true*.
    - :param attribute: attribute dot access substitution, default is ``"__"``.
    - :param exception: user function to reraise database exceptions.
    - :param debug: debug mode, generate more logs through ``logging``.
    - :param cacher: cache factory for queries marked as such.
    - :param cached: doc string re for checking whether to cache a query, default is ``r"\\bCACHED\\b"``
    - :param last_calls: keep track of this method invocations, default is *1*.
    - :param **conn_options: database-specific ``kwargs`` connection options.
    """

    # global counter to help identify DB objects in logs
    _counter = 0

    # database connection driver and variants, with a little hardcoding
    SQLITE = ("sqlite3", "sqlite")
    POSTGRES = ("psycopg", "pg", "postgres", "postgresql", "psycopg3")
    # others stay as-is: psycopg2 pg8000…

    # connection (re)connection backoff delays, in seconds
    _CONNECTION_MIN_DELAY = 0.001
    _CONNECTION_MAX_DELAY = 30.0

    # aiosql select operations, the only ones eligible for caching
    _SELECT_OPS = (Ops.SELECT, Ops.SELECT_ONE, Ops.SELECT_VALUE)
def _log_info(self, m: str):
log.info(f"DB:{self._db}:{self._id} {m}")
def _log_debug(self, m: str):
log.debug(f"DB:{self._db}:{self._id} {m}")
def _log_warning(self, m: str):
log.warning(f"DB:{self._db}:{self._id} {m}")
def _log_error(self, m: str):
log.error(f"DB:{self._db}:{self._id} {m}")
    def __init__(
        self,
        db: str,
        conn: str|None,
        queries: str|list[str] = [],
        options: str|dict[str, Any] = {},  # undocumented backward compatibility
        # further connection options
        conn_args: list[str] = [],
        conn_kwargs: dict[str, Any] = {},
        # adapter creation options
        adapter_args: list[Any] = [],
        adapter_kwargs: dict[str, Any] = {},
        # anodb behavior
        auto_reconnect: bool = True,
        auto_rollback: bool = True,
        debug: bool = False,
        exception: Callable[[BaseException], BaseException]|None = None,
        cacher: CacheFactory|None = None,
        cached: str = r"\bCACHED\b",
        last_calls: int = 1,
        # aiosql behavior
        kwargs_only: bool = True,
        attribute: str = "__",
        # connection options
        **conn_options,
    ):
        # NOTE(review): the mutable default values above are only read or
        # copied, never mutated, so they are safe despite the usual caveat.
        # unique instance id, used for log tagging
        DB._counter += 1
        self._id = DB._counter
        self.__version__ = __version__
        self.__aiosql_version__ = pkg_version("aiosql")
        # this is the class name, after normalizing driver aliases
        self._db = (
            "sqlite3" if db in self.SQLITE else "psycopg" if db in self.POSTGRES else db
        ).lower()
        assert self._db in sql.aiosql._ADAPTERS, f"database {db} is supported"
        self._log_info("creating DB")
        self._set_db_pkg()
        # connection parameters…
        self._conn = None
        # positional connect() arguments: optional conn string first
        self._conn_args = [] if conn is None else [conn]
        self._conn_args.extend(conn_args)
        self._conn_kwargs: dict[str, Any] = dict(conn_kwargs)
        self._adapter: DriverAdapterProtocol|None = None
        # backward compatibility for "options": a str is parsed as a dict literal
        if isinstance(options, str):
            import ast
            self._conn_kwargs.update(ast.literal_eval(options))
        elif isinstance(options, dict):
            self._conn_kwargs.update(options)
        else:
            raise AnoDBException(f"unexpected type for options: {type(options)}")
        # remaining parameters are associated to the connection
        self._conn_kwargs.update(conn_options)
        # adapter creation parameters (copied to detach from caller)
        self._adapter_args = list(adapter_args)
        self._adapter_kwargs = dict(adapter_kwargs)
        # useful global stats
        self._count: dict[str, int] = {}  # name -> #calls
        self._conn_last = None  # current connection start
        self._conn_count: int = 0  # how many connections succeeded
        self._conn_total: int = 0  # number of "executions" in this connection
        self._conn_ntx: int = 0  # number of tx in this connection
        self._conn_nstat: int = 0  # number of executions (fn, cursors) in current tx
        self._total: int = 0  # total number of executions
        self._ntx: int = 0  # number of tx
        # various boolean flags
        self._debug = debug
        if debug:
            log.setLevel(logging.DEBUG)
            self._log_info("running in debug mode…")
        self._auto_reconnect = auto_reconnect
        self._auto_rollback = auto_rollback
        self._kwargs_only = kwargs_only
        # whether a reconnection should be attempted before the next call
        self._reconn = False
        # other parameters
        self._attribute = attribute
        self._exception = exception
        self._cacher = cacher
        self._cached = cached
        self._last_calls = last_calls
        self._calls = deque()  # most recent query names, trimmed to _last_calls
        # queries… keep track of calls
        self._queries_file = [queries] if isinstance(queries, str) else queries
        self._queries: list[sql.aiosql.Queries] = []  # type: ignore
        self._available_queries: set[str] = set()
        for fn in self._queries_file:
            self.add_queries_from_path(fn)
        # last thing is to actually create the connection, which may fail
        self._conn_delay: float|None = None  # seconds
        self._conn_last_fail: dt.datetime|None = None
        self._conn_attempts: int = 0
        self._conn_failures: int = 0
        self._do_connect()
def _possibly_reconnect(self):
"""Detect a connection error for psycopg."""
# FIXME detect other cases of bad connections?
self._reconn = self._auto_reconnect and (
(self._db == "psycopg" and hasattr(self._conn, "closed") and self._conn.closed) or # type: ignore
(self._db == "psycopg2" and hasattr(self._conn, "closed") and self._conn.closed == 2)) # type: ignore
    def _call_fn(self, _query, _fn, *args, **kwargs):
        """Forward method call to aiosql query.

        On connection failure, it will try to reconnect on the next call
        if auto_reconnect was set.
        This may or may not be a good idea, but it should be: the failure
        raises an exception which should abort the current request, so that
        the next call should be on a different request.
        """
        _ = self._debug and self._log_debug(f"{_query}({args}, {kwargs})")
        # a previous failure flagged the connection as dead: reconnect first
        if self._reconn and self._auto_reconnect:
            self._reconnect()
        # per-call bookkeeping
        self._conn_nstat += 1
        self._count[_query] += 1
        if self._last_calls:
            # keep only the last _last_calls query names
            self._calls.append(_query)
            while len(self._calls) > self._last_calls:
                self._calls.popleft()
        try:
            # aiosql functions expect the connection as first argument
            return _fn(self._conn, *args, **kwargs)
        except self._db_error as error:
            self._log_info(f"query {_query} failed: {error}")
            if self._auto_rollback:
                # best-effort rollback so the connection stays usable
                try:
                    if self._conn:
                        self._conn.rollback()
                except self._db_error as rolerr:
                    self._log_warning(f"rollback failed: {rolerr}")
            self._possibly_reconnect()
            # re-raise error, possibly through the user-provided wrapper
            raise self._exception(error) if self._exception else error
        except Exception as e:  # pragma: no cover
            self._log_error(f"unexpected exception: {e}")
            raise
    def _create_fn(self, q: str, f: Callable) -> Callable:
        """Create one wrapped method for query *q* implemented by *f*.

        The wrapper injects the current connection; when a cacher is set
        and the query docstring matches the "cached" pattern, eligible
        SELECT queries are additionally wrapped by the cache factory.
        """
        # call internal caller
        @ft.wraps(f)
        def fn(*a, **kw):
            return self._call_fn(q, f, *a, **kw)
        # FIXME cachability may not work on some types? lo?
        # NOTE we skip internal *_cursor attributes
        if self._cacher and not q.endswith("_cursor") and f.__doc__ and re.search(self._cached, f.__doc__):
            operation = f.operation  # type: ignore
            if operation not in self._SELECT_OPS:
                # caching only makes sense for selections
                self._log_warning(f"skip caching non select method: {q} ({operation})")
                return fn
            # else proceed with wrapping
            self._log_debug(f"caching query {q}")
            if operation == Ops.SELECT:  # materialize generator
                @ft.wraps(fn)
                def fx(*a, **kw):
                    return list(fn(*a, **kw))
            else:
                fx = fn
            return self._cacher(q, fx)
        else:
            return fn
    # this could probably be done dynamically by overriding __getattribute__
    def _create_fns(self, queries: sql.aiosql.Queries):  # type: ignore
        """Create call forwarding to insert the database connection."""
        # keep first encountered adapter
        if self._adapter is None:
            self._adapter = queries.driver_adapter
        self._queries.append(queries)
        for q in queries.available_queries:
            f = getattr(queries, q)
            if callable(f):
                self._log_debug(f"adding q={q}")
                # refuse to clobber an existing attribute (method or query)
                if hasattr(self, q):
                    raise AnoDBException(f"cannot override existing method: {q}")
                setattr(self, q, self._create_fn(q, f))
                self._available_queries.add(q)
                self._count[q] = 0
def add_queries_from_path(self, fn: str):
"""Load queries from a file or directory."""
self._create_fns(sql.from_path(fn, self._db, *self._adapter_args,
kwargs_only=self._kwargs_only, attribute=self._attribute,
**self._adapter_kwargs))
def add_queries_from_str(self, qs: str):
"""Load queries from a string."""
self._create_fns(sql.from_str(qs, self._db, *self._adapter_args,
kwargs_only=self._kwargs_only, attribute=self._attribute,
**self._adapter_kwargs))
def _set_db_pkg(self):
"""Load database package."""
package, module = self._db, self._db
# skip apsw as DB API support is really partial?
if self._db == "pygresql":
package, module = "pgdb", "pgdb"
elif self._db in ("MySQLdb", "mysqldb"): # pragma: no cover
package, module = "MySQLdb", "mysqlclient"
elif self._db in ("mysql-connector", "mysql.connector"): # pragma: no cover
package, module = "mysql.connector", "mysql.connector"
else:
pass
# record db package
try:
self._db_pkg = importlib.import_module(package)
except ImportError: # pragma: no cover
self._log_error(f"cannot import {package} for {self._db}")
raise
# best effort only
try:
self._db_version = pkg_version(module)
except Exception:
self._db_version = "<unknown>"
# get exception class
self._db_error = self._db_pkg.Error if hasattr(self._db_pkg, "Error") else Exception
# myco does not need to follow the standard?
if self._db_error is Exception: # pragma: no cover
self._log_error(f"missing Error class in {package}, falling back to Exception")
def __connect(self):
"""Create a database connection (internal)."""
self._log_info(f"{self._db}: connecting")
# PEP249 does not impose a unified signature for connect.
return self._db_pkg.connect(*self._conn_args, **self._conn_kwargs)
    def _do_connect(self):
        """Create a connection, possibly with active throttling.

        On repeated failures the backoff delay doubles, starting at
        _CONNECTION_MIN_DELAY and capped at _CONNECTION_MAX_DELAY seconds;
        the wait accounts for time already elapsed since the last failure.
        """
        self._conn = None
        try:
            self._conn_attempts += 1
            if self._conn_delay is not None:
                # a previous attempt failed: possibly sleep before retrying
                delay = dt.timedelta(seconds=self._conn_delay)
                assert self._conn_last_fail
                wait = (delay - (dt.datetime.now(dt.timezone.utc) - self._conn_last_fail)).total_seconds()
                if wait > 0.0:
                    self._log_info(f"connection wait #{self._conn_attempts}: {wait}")
                    time.sleep(wait)
            self._conn = self.__connect()
            # on success, update stats
            self._ntx += self._conn_ntx
            self._total += self._conn_total
            self._conn_count += 1
            self._conn_last = dt.datetime.now(dt.timezone.utc)
            self._conn_total = 0
            self._conn_ntx = 0
            # on success, reset reconnection stuff
            self._conn_attempts = 0
            self._conn_last_fail = None
            self._conn_delay = None
        except self._db_error as e:
            # on failure, record the failure and grow the backoff delay
            self._conn_failures += 1
            self._conn_last = None
            self._conn_last_fail = dt.datetime.now(dt.timezone.utc)
            if self._conn_delay is None:
                self._conn_delay = self._CONNECTION_MIN_DELAY
            else:
                self._conn_delay = min(2 * self._conn_delay, self._CONNECTION_MAX_DELAY)
            self._log_error(f"connect failed #{self._conn_attempts}: {e}")
            raise e
def _reconnect(self):
"""Try to reconnect to database, possibly with some cleanup."""
self._log_info(f"{self._db}: reconnecting")
if self._conn:
# attempt at closing but ignore errors
try:
self._conn.close()
except self._db_error as error: # pragma: no cover
self._log_error(f"DB {self._db} close: {error}")
self._do_connect()
self._reconn = False
def connect(self):
"""Create (if needed) and return the database connection."""
if "_conn" not in self.__dict__ or not self._conn:
self._do_connect()
return self._conn
def cursor(self):
"""Get a cursor on the current connection."""
if self._reconn and self._auto_reconnect:
self._reconnect()
assert self._conn is not None and self._adapter is not None
self._conn_nstat += 1
if hasattr(self._adapter, "_cursor"):
return self._adapter._cursor(self._conn) # type: ignore
else: # pragma: no cover
return self._conn.cursor()
def commit(self):
"""Commit database transaction."""
assert self._conn is not None
self._conn_ntx += 1
self._conn_total += self._conn_nstat
self._conn_nstat = 0
self._conn.commit()
def rollback(self):
"""Rollback database transaction."""
assert self._conn is not None
self._conn_ntx += 1
self._conn_total += self._conn_nstat
self._conn_nstat = 0
self._conn.rollback()
def close(self):
"""Close underlying database connection if needed."""
if self._conn is not None:
# NOTE only reset if close succeeded?
self._conn.close()
self._conn = None
# should we try to reconnect?
self._reconn = self._auto_reconnect
    def _stats(self):
        """Generate a JSON-compatible structure for statistics."""
        return {
            "id": self._id,
            "driver": self._db,
            "info": self._conn_args,  # _conn_kwargs?
            "conn": {
                # current connection status
                "nstat": self._conn_nstat,
                "total": self._conn_total,
                "ntx": self._conn_ntx,
                "last": self._conn_last.isoformat() if self._conn_last else None,
                # (re)connection attempt status
                "attempts": self._conn_attempts,
                "failures": self._conn_failures,
                "delay": self._conn_delay,
                "last-fail": self._conn_last_fail.isoformat() if self._conn_last_fail else None,
            },
            # life time totals, accumulated over all connections
            "total": self._total,
            "ntx": self._ntx,
            "calls": self._count,
            "count": self._conn_count,
            "lasts": list(self._calls),
        }
def __str__(self):
return json.dumps(self._stats())
def __del__(self):
if hasattr(self, "_conn") and self._conn:
self.close()