Add support for asynchronous workflows with ContextVar #18

Open · wants to merge 4 commits into base: master
83 changes: 37 additions & 46 deletions dynamic_db_router/router.py
@@ -1,22 +1,25 @@
import threading
from contextvars import ContextVar
from functools import wraps
from uuid import uuid4

from django.db import connections

THREAD_LOCAL = threading.local()
# Define context variables for read and write database settings
# These variables will maintain database preferences per context
DB_FOR_READ_OVERRIDE = ContextVar('DB_FOR_READ_OVERRIDE', default='default')
DB_FOR_WRITE_OVERRIDE = ContextVar('DB_FOR_WRITE_OVERRIDE', default='default')
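As background, a minimal, self-contained sketch of why ContextVar suits asynchronous code better than threading.local: each asyncio task runs in a copy of the current context, so a value set in one task is invisible to tasks running concurrently. The variable name below is hypothetical and merely mirrors the two module-level variables above.

import asyncio
from contextvars import ContextVar

# Hypothetical variable mirroring DB_FOR_READ_OVERRIDE above.
current_db = ContextVar('current_db', default='default')

async def use_replica():
    current_db.set('replica')   # visible only inside this task's context
    await asyncio.sleep(0)
    return current_db.get()     # 'replica'

async def use_default():
    await asyncio.sleep(0)
    return current_db.get()     # 'default', unaffected by the other task

async def main():
    # gather() wraps each coroutine in a Task, and each Task runs in a copy of
    # the current context, so the set() above never leaks across tasks.
    print(await asyncio.gather(use_replica(), use_default()))  # ['replica', 'default']

asyncio.run(main())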


class DynamicDbRouter(object):
"""A router that decides what db to read from based on a variable
local to the current thread.
class DynamicDbRouter:
"""

A router that dynamically determines which database to perform read and write operations
on based on the current execution context. It supports both synchronous and asynchronous code.
"""

def db_for_read(self, model, **hints):
return getattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE', ['default'])[-1]
return DB_FOR_READ_OVERRIDE.get()

def db_for_write(self, model, **hints):
return getattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE', ['default'])[-1]
return DB_FOR_WRITE_OVERRIDE.get()

def allow_relation(self, *args, **kwargs):
return True
@@ -27,101 +30,89 @@ def allow_syncdb(self, *args, **kwargs):
def allow_migrate(self, *args, **kwargs):
return None
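For reference, a minimal settings sketch showing how a router like this is typically registered with Django. The dotted path is assumed from the dynamic_db_router/router.py location in this diff; the aliases and engine are placeholders.

# settings.py (sketch)
DATABASES = {
    'default': {'ENGINE': 'django.db.backends.sqlite3', 'NAME': 'default.sqlite3'},
    'replica': {'ENGINE': 'django.db.backends.sqlite3', 'NAME': 'replica.sqlite3'},
}

# Dotted path assumed from the file location shown in this diff.
DATABASE_ROUTERS = ['dynamic_db_router.router.DynamicDbRouter']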


class in_database(object):
"""A decorator and context manager to do queries on a given database.

class in_database:
"""
A decorator and context manager to do queries on a given database.
:type database: str or dict
:param database: The database to run queries on. A string
will route through the matching database in
``django.conf.settings.DATABASES``. A dictionary will set up a
connection with the given configuration and route queries to it.

:type read: bool, optional
:param read: Controls whether database reads will route through
the provided database. If ``False``, reads will route through
the ``'default'`` database. Defaults to ``True``.

:type write: bool, optional
:param write: Controls whether database writes will route to
the provided database. If ``False``, writes will route to
the ``'default'`` database. Defaults to ``False``.

When used as either a decorator or a context manager, `in_database`
requires a single argument, which is the name of the database to
route queries to, or a configuration dictionary for a database to
route to.

Usage as a context manager:

.. code-block:: python

from my_django_app.utils import tricky_query

with in_database('Database_A'):
results = tricky_query()

Usage as a decorator:

.. code-block:: python

from my_django_app.models import Account

@in_database('Database_B')
def lowest_id_account():
return Account.objects.order_by('id')[0]

Used with a configuration dictionary:

.. code-block:: python

db_config = {'ENGINE': 'django.db.backends.sqlite3',
'NAME': 'path/to/mydatabase.db'}
with in_database(db_config):
# Run queries
"""
def __init__(self, database, read=True, write=False):
def __init__(self, database: str | dict, read=True, write=False):
self.read = read
self.write = write
self.database = database
self.created_db_config = False

# Handle database parameter either as a string (alias) or as a dict (configuration)
if isinstance(database, str):
self.database = database
elif isinstance(database, dict):
# Note: this invalidates the docs above. Update them
# eventually.
# If it's a dict, create a unique database configuration
self.created_db_config = True
self.unique_db_id = str(uuid4())
connections.databases[self.unique_db_id] = database
self.database = self.unique_db_id
else:
msg = ("database must be an identifier for an existing db, "
"or a complete configuration.")
raise ValueError(msg)
raise ValueError("database must be an identifier (str) for an existing db, "
"or a complete configuration (dict).")

def __enter__(self):
if not hasattr(THREAD_LOCAL, 'DB_FOR_READ_OVERRIDE'):
THREAD_LOCAL.DB_FOR_READ_OVERRIDE = ['default']
if not hasattr(THREAD_LOCAL, 'DB_FOR_WRITE_OVERRIDE'):
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE = ['default']
read_db = (self.database if self.read
else THREAD_LOCAL.DB_FOR_READ_OVERRIDE[-1])
write_db = (self.database if self.write
else THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE[-1])
THREAD_LOCAL.DB_FOR_READ_OVERRIDE.append(read_db)
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE.append(write_db)
# Capture the current database settings
self.original_read_db = DB_FOR_READ_OVERRIDE.get()
self.original_write_db = DB_FOR_WRITE_OVERRIDE.get()

# Override the database settings for the duration of the context
if self.read:
DB_FOR_READ_OVERRIDE.set(self.database)
if self.write:
DB_FOR_WRITE_OVERRIDE.set(self.database)
return self

def __exit__(self, exc_type, exc_value, traceback):
THREAD_LOCAL.DB_FOR_READ_OVERRIDE.pop()
THREAD_LOCAL.DB_FOR_WRITE_OVERRIDE.pop()
# Restore the original database settings after the context.
DB_FOR_READ_OVERRIDE.set(self.original_read_db)
DB_FOR_WRITE_OVERRIDE.set(self.original_write_db)

# Close and delete the database configuration created from a dict argument
if self.created_db_config:
connections[self.unique_db_id].close()
del connections.databases[self.unique_db_id]
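To illustrate the capture-and-restore behaviour of __enter__ and __exit__ above, a sketch using placeholder aliases; no query is issued, so the aliases need not exist, and the import path is assumed from this diff.

from dynamic_db_router.router import DB_FOR_READ_OVERRIDE, in_database

print(DB_FOR_READ_OVERRIDE.get())          # 'default'
with in_database('replica_a'):
    print(DB_FOR_READ_OVERRIDE.get())      # 'replica_a'
    with in_database('replica_b'):
        print(DB_FOR_READ_OVERRIDE.get())  # 'replica_b'
    print(DB_FOR_READ_OVERRIDE.get())      # restored to 'replica_a'
print(DB_FOR_READ_OVERRIDE.get())          # restored to 'default'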

def __call__(self, querying_func):
# Allow the object to be used as a decorator
@wraps(querying_func)
def inner(*args, **kwargs):
# Call the function in our context manager
with self:
return querying_func(*args, **kwargs)
return inner
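Finally, a sketch of the asynchronous workflow this change targets. It assumes asgiref's documented behaviour of propagating context variables across the sync_to_async boundary; Account is a placeholder model and 'replica' a placeholder alias.

import asyncio
from asgiref.sync import sync_to_async

from dynamic_db_router.router import in_database
from my_app.models import Account  # placeholder model

async def replica_count():
    # The ContextVar override set by in_database is local to this task,
    # even across the await below.
    with in_database('replica'):
        return await sync_to_async(Account.objects.count)()

async def default_count():
    # No override in this task, so reads stay on 'default'.
    return await sync_to_async(Account.objects.count)()

async def main():
    print(await asyncio.gather(replica_count(), default_count()))

# asyncio.run(main())  # requires a configured Django project with both aliases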