Skip to content

Commit e48dcdc

Browse files
authored
Merge pull request #1 from clearpol/dy/feat/async-support
feat: add support for async operations
2 parents 06f6666 + 8ed828a commit e48dcdc

File tree

1 file changed

+37
-46
lines changed

1 file changed

+37
-46
lines changed

dynamic_db_router/router.py

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
1-
import threading
1+
from contextvars import ContextVar
22
from functools import wraps
33
from uuid import uuid4
4-
54
from django.db import connections
65

7-
THREAD_LOCAL = threading.local()
6+
# Define context variables for read and write database settings
7+
# These variables will maintain database preferences per context
8+
DB_FOR_READ_OVERRIDE = ContextVar('DB_FOR_READ_OVERRIDE', default='default')
9+
DB_FOR_WRITE_OVERRIDE = ContextVar('DB_FOR_WRITE_OVERRIDE', default='default')
810

911

10-
class DynamicDbRouter(object):
11-
"""A router that decides what db to read from based on a variable
12-
local to the current thread.
12+
class DynamicDbRouter:
1313
"""
14-
14+
A router that dynamically determines which database to perform read and write operations
15+
on based on the current execution context. It supports both synchronous and asynchronous code.
16+
"""
17+
1518
def db_for_read(self, model, **hints):
16-
return getattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE', ['default'])[-1]
19+
return DB_FOR_READ_OVERRIDE.get()
1720

1821
def db_for_write(self, model, **hints):
19-
return getattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE', ['default'])[-1]
22+
return DB_FOR_WRITE_OVERRIDE.get()
2023

2124
def allow_relation(self, *args, **kwargs):
2225
return True
@@ -27,101 +30,89 @@ def allow_syncdb(self, *args, **kwargs):
2730
def allow_migrate(self, *args, **kwargs):
2831
return None
2932

30-
31-
class in_database(object):
32-
"""A decorator and context manager to do queries on a given database.
33-
33+
class in_database:
34+
"""
35+
A decorator and context manager to do queries on a given database.
3436
:type database: str or dict
3537
:param database: The database to run queries on. A string
3638
will route through the matching database in
3739
``django.conf.settings.DATABASES``. A dictionary will set up a
3840
connection with the given configuration and route queries to it.
39-
4041
:type read: bool, optional
4142
:param read: Controls whether database reads will route through
4243
the provided database. If ``False``, reads will route through
4344
the ``'default'`` database. Defaults to ``True``.
44-
4545
:type write: bool, optional
4646
:param write: Controls whether database writes will route to
4747
the provided database. If ``False``, writes will route to
4848
the ``'default'`` database. Defaults to ``False``.
49-
5049
When used as eithe a decorator or a context manager, `in_database`
5150
requires a single argument, which is the name of the database to
5251
route queries to, or a configuration dictionary for a database to
5352
route to.
54-
5553
Usage as a context manager:
56-
5754
.. code-block:: python
58-
5955
from my_django_app.utils import tricky_query
60-
6156
with in_database('Database_A'):
6257
results = tricky_query()
63-
6458
Usage as a decorator:
65-
6659
.. code-block:: python
67-
6860
from my_django_app.models import Account
69-
7061
@in_database('Database_B')
7162
def lowest_id_account():
7263
Account.objects.order_by('-id')[0]
73-
7464
Used with a configuration dictionary:
75-
7665
.. code-block:: python
77-
7866
db_config = {'ENGINE': 'django.db.backends.sqlite3',
7967
'NAME': 'path/to/mydatabase.db'}
8068
with in_database(db_config):
8169
# Run queries
8270
"""
83-
def __init__(self, database, read=True, write=False):
71+
def __init__(self, database: str | dict, read=True, write=False):
8472
self.read = read
8573
self.write = write
74+
self.database = database
8675
self.created_db_config = False
76+
77+
# Handle database parameter either as a string (alias) or as a dict (configuration)
8778
if isinstance(database, str):
8879
self.database = database
8980
elif isinstance(database, dict):
90-
# Note: this invalidates the docs above. Update them
91-
# eventually.
81+
# If it's a dict, create a unique database configuration
9282
self.created_db_config = True
9383
self.unique_db_id = str(uuid4())
9484
connections.databases[self.unique_db_id] = database
9585
self.database = self.unique_db_id
9686
else:
97-
msg = ("database must be an identifier for an existing db, "
98-
"or a complete configuration.")
99-
raise ValueError(msg)
87+
raise ValueError("database must be an identifier (str) for an existing db, "
88+
"or a complete configuration (dict).")
10089

10190
def __enter__(self):
102-
if not hasattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE'):
103-
THREAD_LOCAL.DB_FOR_READ_OVERRIDE = ['default']
104-
if not hasattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE'):
105-
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE = ['default']
106-
read_db = (self.database if self.read
107-
else THREAD_LOCAL.DB_FOR_READ_OVERRIDE[-1])
108-
write_db = (self.database if self.write
109-
else THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE[-1])
110-
THREAD_LOCAL.DB_FOR_READ_OVERRIDE.append(read_db)
111-
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE.append(write_db)
91+
# Capture the current database settings
92+
self.original_read_db = DB_FOR_READ_OVERRIDE.get()
93+
self.original_write_db = DB_FOR_WRITE_OVERRIDE.get()
94+
95+
# Override the database settings for the duration of the context
96+
if self.read:
97+
DB_FOR_READ_OVERRIDE.set(self.database)
98+
if self.write:
99+
DB_FOR_WRITE_OVERRIDE.set(self.database)
112100
return self
113101

114102
def __exit__(self, exc_type, exc_value, traceback):
115-
THREAD_LOCAL.DB_FOR_READ_OVERRIDE.pop()
116-
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE.pop()
103+
# Restore the original database settings after the context.
104+
DB_FOR_READ_OVERRIDE.set(self.original_read_db)
105+
DB_FOR_WRITE_OVERRIDE.set(self.original_write_db)
106+
107+
# Close and delete created database configuration
117108
if self.created_db_config:
118109
connections[self.unique_db_id].close()
119110
del connections.databases[self.unique_db_id]
120111

121112
def __call__(self, querying_func):
113+
# Allow the object to be used as a decorator
122114
@wraps(querying_func)
123115
def inner(*args, **kwargs):
124-
# Call the function in our context manager
125116
with self:
126117
return querying_func(*args, **kwargs)
127118
return inner

0 commit comments

Comments
 (0)