|
1 | 1 | import threading
|
2 | 2 | from functools import wraps
|
3 | 3 | from uuid import uuid4
|
| 4 | +from contextvars import ContextVar |
4 | 5 |
|
5 | 6 | from django.db import connections
|
6 | 7 |
|
7 | 8 | THREAD_LOCAL = threading.local()
|
8 | 9 |
|
| 10 | +DB_FOR_READ_OVERRIDE = ContextVar('DB_FOR_READ_OVERRIDE', default='default') |
| 11 | +DB_FOR_WRITE_OVERRIDE = ContextVar('DB_FOR_WRITE_OVERRIDE', default='default') |
| 12 | + |
9 | 13 |
|
10 | 14 | class DynamicDbRouter(object):
|
11 | 15 | """A router that decides what db to read from based on a variable
|
12 | 16 | local to the current thread.
|
13 | 17 | """
|
14 |
| - |
| 18 | + |
15 | 19 | def db_for_read(self, model, **hints):
|
16 |
| - return getattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE', ['default'])[-1] |
17 |
| - |
| 20 | + return DB_FOR_READ_OVERRIDE.get() |
| 21 | + # return getattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE', ['default'])[-1] |
| 22 | + |
18 | 23 | def db_for_write(self, model, **hints):
|
19 |
| - return getattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE', ['default'])[-1] |
20 |
| - |
| 24 | + return DB_FOR_WRITE_OVERRIDE.get() |
| 25 | + # return getattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE', ['default'])[-1] |
| 26 | + |
21 | 27 | def allow_relation(self, *args, **kwargs):
|
22 | 28 | return True
|
23 |
| - |
| 29 | + |
24 | 30 | def allow_syncdb(self, *args, **kwargs):
|
25 | 31 | return None
|
26 |
| - |
| 32 | + |
27 | 33 | def allow_migrate(self, *args, **kwargs):
|
28 | 34 | return None
|
29 | 35 |
|
@@ -83,45 +89,33 @@ def lowest_id_account():
|
83 | 89 | def __init__(self, database, read=True, write=False):
|
84 | 90 | self.read = read
|
85 | 91 | self.write = write
|
| 92 | + self.database = database |
86 | 93 | self.created_db_config = False
|
87 |
| - if isinstance(database, str): |
88 |
| - self.database = database |
89 |
| - elif isinstance(database, dict): |
90 |
| - # Note: this invalidates the docs above. Update them |
91 |
| - # eventually. |
| 94 | + if isinstance(database, dict): |
92 | 95 | self.created_db_config = True
|
93 | 96 | self.unique_db_id = str(uuid4())
|
94 | 97 | connections.databases[self.unique_db_id] = database
|
95 | 98 | self.database = self.unique_db_id
|
96 |
| - else: |
97 |
| - msg = ("database must be an identifier for an existing db, " |
98 |
| - "or a complete configuration.") |
99 |
| - raise ValueError(msg) |
100 |
| - |
| 99 | + |
101 | 100 | def __enter__(self):
|
102 |
| - if not hasattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE'): |
103 |
| - THREAD_LOCAL.DB_FOR_READ_OVERRIDE = ['default'] |
104 |
| - if not hasattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE'): |
105 |
| - THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE = ['default'] |
106 |
| - read_db = (self.database if self.read |
107 |
| - else THREAD_LOCAL.DB_FOR_READ_OVERRIDE[-1]) |
108 |
| - write_db = (self.database if self.write |
109 |
| - else THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE[-1]) |
110 |
| - THREAD_LOCAL.DB_FOR_READ_OVERRIDE.append(read_db) |
111 |
| - THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE.append(write_db) |
| 101 | + self.original_read_db = DB_FOR_READ_OVERRIDE.get() |
| 102 | + self.original_write_db = DB_FOR_WRITE_OVERRIDE.get() |
| 103 | + if self.read: |
| 104 | + DB_FOR_READ_OVERRIDE.set(self.database) |
| 105 | + if self.write: |
| 106 | + DB_FOR_WRITE_OVERRIDE.set(self.database) |
112 | 107 | return self
|
113 |
| - |
| 108 | + |
114 | 109 | def __exit__(self, exc_type, exc_value, traceback):
|
115 |
| - THREAD_LOCAL.DB_FOR_READ_OVERRIDE.pop() |
116 |
| - THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE.pop() |
| 110 | + DB_FOR_READ_OVERRIDE.set(self.original_read_db) |
| 111 | + DB_FOR_WRITE_OVERRIDE.set(self.original_write_db) |
117 | 112 | if self.created_db_config:
|
118 | 113 | connections[self.unique_db_id].close()
|
119 | 114 | del connections.databases[self.unique_db_id]
|
120 |
| - |
| 115 | + |
121 | 116 | def __call__(self, querying_func):
|
122 | 117 | @wraps(querying_func)
|
123 | 118 | def inner(*args, **kwargs):
|
124 |
| - # Call the function in our context manager |
125 | 119 | with self:
|
126 | 120 | return querying_func(*args, **kwargs)
|
127 | 121 | return inner
|
0 commit comments