Skip to content

Commit cd8d593

Browse files
feat (add support for async operations)
1 parent 06f6666 commit cd8d593

File tree

1 file changed

+26
-32
lines changed

1 file changed

+26
-32
lines changed

dynamic_db_router/router.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,35 @@
11
import threading
22
from functools import wraps
33
from uuid import uuid4
4+
from contextvars import ContextVar
45

56
from django.db import connections
67

78
THREAD_LOCAL = threading.local()
89

10+
DB_FOR_READ_OVERRIDE = ContextVar('DB_FOR_READ_OVERRIDE', default='default')
11+
DB_FOR_WRITE_OVERRIDE = ContextVar('DB_FOR_WRITE_OVERRIDE', default='default')
12+
913

1014
class DynamicDbRouter(object):
1115
"""A router that decides what db to read from based on a variable
1216
local to the current thread.
1317
"""
14-
18+
1519
def db_for_read(self, model, **hints):
16-
return getattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE', ['default'])[-1]
17-
20+
return DB_FOR_READ_OVERRIDE.get()
21+
# return getattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE', ['default'])[-1]
22+
1823
def db_for_write(self, model, **hints):
19-
return getattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE', ['default'])[-1]
20-
24+
return DB_FOR_WRITE_OVERRIDE.get()
25+
# return getattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE', ['default'])[-1]
26+
2127
def allow_relation(self, *args, **kwargs):
2228
return True
23-
29+
2430
def allow_syncdb(self, *args, **kwargs):
2531
return None
26-
32+
2733
def allow_migrate(self, *args, **kwargs):
2834
return None
2935

@@ -83,45 +89,33 @@ def lowest_id_account():
8389
def __init__(self, database, read=True, write=False):
8490
self.read = read
8591
self.write = write
92+
self.database = database
8693
self.created_db_config = False
87-
if isinstance(database, str):
88-
self.database = database
89-
elif isinstance(database, dict):
90-
# Note: this invalidates the docs above. Update them
91-
# eventually.
94+
if isinstance(database, dict):
9295
self.created_db_config = True
9396
self.unique_db_id = str(uuid4())
9497
connections.databases[self.unique_db_id] = database
9598
self.database = self.unique_db_id
96-
else:
97-
msg = ("database must be an identifier for an existing db, "
98-
"or a complete configuration.")
99-
raise ValueError(msg)
100-
99+
101100
def __enter__(self):
102-
if not hasattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE'):
103-
THREAD_LOCAL.DB_FOR_READ_OVERRIDE = ['default']
104-
if not hasattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE'):
105-
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE = ['default']
106-
read_db = (self.database if self.read
107-
else THREAD_LOCAL.DB_FOR_READ_OVERRIDE[-1])
108-
write_db = (self.database if self.write
109-
else THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE[-1])
110-
THREAD_LOCAL.DB_FOR_READ_OVERRIDE.append(read_db)
111-
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE.append(write_db)
101+
self.original_read_db = DB_FOR_READ_OVERRIDE.get()
102+
self.original_write_db = DB_FOR_WRITE_OVERRIDE.get()
103+
if self.read:
104+
DB_FOR_READ_OVERRIDE.set(self.database)
105+
if self.write:
106+
DB_FOR_WRITE_OVERRIDE.set(self.database)
112107
return self
113-
108+
114109
def __exit__(self, exc_type, exc_value, traceback):
115-
THREAD_LOCAL.DB_FOR_READ_OVERRIDE.pop()
116-
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE.pop()
110+
DB_FOR_READ_OVERRIDE.set(self.original_read_db)
111+
DB_FOR_WRITE_OVERRIDE.set(self.original_write_db)
117112
if self.created_db_config:
118113
connections[self.unique_db_id].close()
119114
del connections.databases[self.unique_db_id]
120-
115+
121116
def __call__(self, querying_func):
122117
@wraps(querying_func)
123118
def inner(*args, **kwargs):
124-
# Call the function in our context manager
125119
with self:
126120
return querying_func(*args, **kwargs)
127121
return inner

0 commit comments

Comments
 (0)