1
- import threading
1
+ from contextvars import ContextVar
2
2
from functools import wraps
3
3
from uuid import uuid4
4
-
5
4
from django .db import connections
6
5
7
- THREAD_LOCAL = threading .local ()
6
+ # Define context variables for read and write database settings
7
+ # These variables will maintain database preferences per context
8
+ DB_FOR_READ_OVERRIDE = ContextVar ('DB_FOR_READ_OVERRIDE' , default = 'default' )
9
+ DB_FOR_WRITE_OVERRIDE = ContextVar ('DB_FOR_WRITE_OVERRIDE' , default = 'default' )
8
10
9
11
10
- class DynamicDbRouter (object ):
11
- """A router that decides what db to read from based on a variable
12
- local to the current thread.
12
+ class DynamicDbRouter :
13
13
"""
14
-
14
+ A router that dynamically determines which database to perform read and write operations
15
+ on based on the current execution context. It supports both synchronous and asynchronous code.
16
+ """
17
+
15
18
def db_for_read (self , model , ** hints ):
16
- return getattr ( THREAD_LOCAL , ' DB_FOR_READ_OVERRIDE' , [ 'default' ])[ - 1 ]
19
+ return DB_FOR_READ_OVERRIDE . get ()
17
20
18
21
def db_for_write (self , model , ** hints ):
19
- return getattr ( THREAD_LOCAL , ' DB_FOR_WRITE_OVERRIDE' , [ 'default' ])[ - 1 ]
22
+ return DB_FOR_WRITE_OVERRIDE . get ()
20
23
21
24
def allow_relation (self , * args , ** kwargs ):
22
25
return True
@@ -27,101 +30,89 @@ def allow_syncdb(self, *args, **kwargs):
27
30
def allow_migrate (self , * args , ** kwargs ):
28
31
return None
29
32
30
-
31
- class in_database (object ):
32
- """A decorator and context manager to do queries on a given database.
33
-
33
+ class in_database :
34
+ """
35
+ A decorator and context manager to do queries on a given database.
34
36
:type database: str or dict
35
37
:param database: The database to run queries on. A string
36
38
will route through the matching database in
37
39
``django.conf.settings.DATABASES``. A dictionary will set up a
38
40
connection with the given configuration and route queries to it.
39
-
40
41
:type read: bool, optional
41
42
:param read: Controls whether database reads will route through
42
43
the provided database. If ``False``, reads will route through
43
44
the ``'default'`` database. Defaults to ``True``.
44
-
45
45
:type write: bool, optional
46
46
:param write: Controls whether database writes will route to
47
47
the provided database. If ``False``, writes will route to
48
48
the ``'default'`` database. Defaults to ``False``.
49
-
50
49
When used as eithe a decorator or a context manager, `in_database`
51
50
requires a single argument, which is the name of the database to
52
51
route queries to, or a configuration dictionary for a database to
53
52
route to.
54
-
55
53
Usage as a context manager:
56
-
57
54
.. code-block:: python
58
-
59
55
from my_django_app.utils import tricky_query
60
-
61
56
with in_database('Database_A'):
62
57
results = tricky_query()
63
-
64
58
Usage as a decorator:
65
-
66
59
.. code-block:: python
67
-
68
60
from my_django_app.models import Account
69
-
70
61
@in_database('Database_B')
71
62
def lowest_id_account():
72
63
Account.objects.order_by('-id')[0]
73
-
74
64
Used with a configuration dictionary:
75
-
76
65
.. code-block:: python
77
-
78
66
db_config = {'ENGINE': 'django.db.backends.sqlite3',
79
67
'NAME': 'path/to/mydatabase.db'}
80
68
with in_database(db_config):
81
69
# Run queries
82
70
"""
83
- def __init__ (self , database , read = True , write = False ):
71
+ def __init__ (self , database : str | dict , read = True , write = False ):
84
72
self .read = read
85
73
self .write = write
74
+ self .database = database
86
75
self .created_db_config = False
76
+
77
+ # Handle database parameter either as a string (alias) or as a dict (configuration)
87
78
if isinstance (database , str ):
88
79
self .database = database
89
80
elif isinstance (database , dict ):
90
- # Note: this invalidates the docs above. Update them
91
- # eventually.
81
+ # If it's a dict, create a unique database configuration
92
82
self .created_db_config = True
93
83
self .unique_db_id = str (uuid4 ())
94
84
connections .databases [self .unique_db_id ] = database
95
85
self .database = self .unique_db_id
96
86
else :
97
- msg = ("database must be an identifier for an existing db, "
98
- "or a complete configuration." )
99
- raise ValueError (msg )
87
+ raise ValueError ("database must be an identifier (str) for an existing db, "
88
+ "or a complete configuration (dict)." )
100
89
101
90
def __enter__ (self ):
102
- if not hasattr (THREAD_LOCAL , 'DB_FOR_READ_OVERRIDE' ):
103
- THREAD_LOCAL .DB_FOR_READ_OVERRIDE = ['default' ]
104
- if not hasattr (THREAD_LOCAL , 'DB_FOR_WRITE_OVERRIDE' ):
105
- THREAD_LOCAL .DB_FOR_WRITE_OVERRIDE = ['default' ]
106
- read_db = (self .database if self .read
107
- else THREAD_LOCAL .DB_FOR_READ_OVERRIDE [- 1 ])
108
- write_db = (self .database if self .write
109
- else THREAD_LOCAL .DB_FOR_WRITE_OVERRIDE [- 1 ])
110
- THREAD_LOCAL .DB_FOR_READ_OVERRIDE .append (read_db )
111
- THREAD_LOCAL .DB_FOR_WRITE_OVERRIDE .append (write_db )
91
+ # Capture the current database settings
92
+ self .original_read_db = DB_FOR_READ_OVERRIDE .get ()
93
+ self .original_write_db = DB_FOR_WRITE_OVERRIDE .get ()
94
+
95
+ # Override the database settings for the duration of the context
96
+ if self .read :
97
+ DB_FOR_READ_OVERRIDE .set (self .database )
98
+ if self .write :
99
+ DB_FOR_WRITE_OVERRIDE .set (self .database )
112
100
return self
113
101
114
102
def __exit__ (self , exc_type , exc_value , traceback ):
115
- THREAD_LOCAL .DB_FOR_READ_OVERRIDE .pop ()
116
- THREAD_LOCAL .DB_FOR_WRITE_OVERRIDE .pop ()
103
+ # Restore the original database settings after the context.
104
+ DB_FOR_READ_OVERRIDE .set (self .original_read_db )
105
+ DB_FOR_WRITE_OVERRIDE .set (self .original_write_db )
106
+
107
+ # Close and delete created database configuration
117
108
if self .created_db_config :
118
109
connections [self .unique_db_id ].close ()
119
110
del connections .databases [self .unique_db_id ]
120
111
121
112
def __call__ (self , querying_func ):
113
+ # Allow the object to be used as a decorator
122
114
@wraps (querying_func )
123
115
def inner (* args , ** kwargs ):
124
- # Call the function in our context manager
125
116
with self :
126
117
return querying_func (* args , ** kwargs )
127
118
return inner
0 commit comments