diff --git a/docs/helpers.rst b/docs/helpers.rst index 7c60f9005..f3096caf0 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -191,6 +191,26 @@ transaction support. This is only required for fixtures which need database access themselves. A test function would normally use the :py:func:`~pytest.mark.django_db` mark to signal it needs the database. +``shared_db_wrapper`` +~~~~~~~~~~~~~~~~~~~~~ + +This fixture can be used to create long-lived state in the database. +It's meant to be used from fixtures with scope bigger than ``function``. +It provides a context manager that will create a new database savepoint for you, +and will take care to revert it when your fixture gets cleaned up. + +At the moment it does not work with ``transactional_db``, +as the fixture itself depends on transactions. +It also needs Django >= 1.8, as earlier versions close DB connections between tests. + +Example usage:: + + @pytest.fixture(scope='module') + def some_users(request, shared_db_wrapper): + with shared_db_wrapper(request): + return [User.objects.create(username='no {}'.format(i)) + for i in range(1000)] + ``live_server`` ~~~~~~~~~~~~~~~ diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index cae7d4763..b91ecad6e 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -2,7 +2,9 @@ from __future__ import with_statement +from contextlib import contextmanager import os +import sys import warnings import pytest @@ -13,7 +15,8 @@ from .django_compat import is_django_unittest from .lazy_django import get_django_version, skip_if_no_django -__all__ = ['_django_db_setup', 'db', 'transactional_db', 'admin_user', +__all__ = ['_django_db_setup', 'db', 'transactional_db', 'shared_db_wrapper', + 'admin_user', 'django_user_model', 'django_username_field', 'client', 'admin_client', 'rf', 'settings', 'live_server', '_live_server_helper'] @@ -195,6 +198,58 @@ def transactional_db(request, _django_db_setup, _django_cursor_wrapper): return _django_db_fixture_helper(True, request, _django_cursor_wrapper) +@pytest.fixture(scope='session') +def shared_db_wrapper(_django_db_setup, _django_cursor_wrapper): + """Wrapper for common database initialization code. + + This fixture provides a context manager that let's you access the database + from a transaction spanning multiple tests. + """ + from django.db import connection, transaction + + if get_django_version() < (1, 8): + raise Exception('shared_db_wrapper is only supported on Django >= 1.8.') + + class DummyException(Exception): + """Dummy for use with Atomic.__exit__.""" + + @contextmanager + def wrapper(request): + # We need to take the request + # to bind finalization to the place where this is used + if 'transactional_db' in request.funcargnames: + raise Exception( + 'shared_db_wrapper cannot be used with `transactional_db`.') + + with _django_cursor_wrapper: + if not connection.features.supports_transactions: + raise Exception( + "shared_db_wrapper cannot be used when " + "the database doesn't support transactions.") + + # Use atomic instead of calling .savepoint* directly. + # This way works for both top-level transactions and "subtransactions". + atomic = transaction.atomic() + + def finalize(): + # dummy exception makes `atomic` rollback the savepoint + with _django_cursor_wrapper: + atomic.__exit__(DummyException, DummyException(), None) + + try: + _django_cursor_wrapper.enable() + atomic.__enter__() + yield + request.addfinalizer(finalize) + except: + atomic.__exit__(*sys.exc_info()) + raise + finally: + _django_cursor_wrapper.restore() + + return wrapper + + @pytest.fixture() def client(): """A Django test client instance.""" diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index 2484eedac..6ed28747f 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -17,14 +17,14 @@ from .django_compat import is_django_unittest from .fixtures import (_django_db_setup, _live_server_helper, admin_client, admin_user, client, db, django_user_model, - django_username_field, live_server, rf, settings, - transactional_db) + django_username_field, live_server, rf, shared_db_wrapper, + settings, transactional_db) from .lazy_django import django_settings_is_configured, skip_if_no_django # Silence linters for imported fixtures. (_django_db_setup, _live_server_helper, admin_client, admin_user, client, db, django_user_model, django_username_field, live_server, rf, settings, - transactional_db) + shared_db_wrapper, transactional_db) SETTINGS_MODULE_ENV = 'DJANGO_SETTINGS_MODULE' diff --git a/tests/test_database.py b/tests/test_database.py index adbc51736..22e936702 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,6 +4,7 @@ from django.db import connection, transaction from django.test.testcases import connections_support_transactions +from pytest_django.lazy_django import get_django_version from pytest_django_test.app.models import Item @@ -51,6 +52,77 @@ def test_noaccess_fixture(noaccess): pass +@pytest.mark.skipif( + get_django_version() < (1, 8), + reason="shared_db_wrapper needs at least Django 1.8") +def test_shared_db_wrapper(django_testdir): + django_testdir.create_test_module(''' + from .app.models import Item + import pytest + from uuid import uuid4 + + @pytest.fixture(scope='session') + def session_item(request, shared_db_wrapper): + with shared_db_wrapper(request): + return Item.objects.create(name='session-' + uuid4().hex) + + @pytest.fixture(scope='module') + def module_item(request, shared_db_wrapper): + with shared_db_wrapper(request): + return Item.objects.create(name='module-' + uuid4().hex) + + @pytest.fixture(scope='class') + def class_item(request, shared_db_wrapper): + with shared_db_wrapper(request): + return Item.objects.create(name='class-' + uuid4().hex) + + @pytest.fixture + def function_item(db): + return Item.objects.create(name='function-' + uuid4().hex) + + class TestItems: + def test_save_the_items( + self, session_item, module_item, class_item, + function_item, db): + global _session_item + global _module_item + global _class_item + assert session_item.pk + assert module_item.pk + assert class_item.pk + _session_item = session_item + _module_item = module_item + _class_item = class_item + + def test_mixing_with_non_db_tests(self): + pass + + def test_accessing_the_same_items( + self, db, session_item, module_item, class_item): + assert _session_item.name == session_item.name + Item.objects.get(pk=_session_item.pk) + assert _module_item.name == module_item.name + Item.objects.get(pk=_module_item.pk) + assert _class_item.name == class_item.name + Item.objects.get(pk=_class_item.pk) + + def test_mixing_with_other_db_tests(db): + Item.objects.get(name=_module_item.name) + assert Item.objects.filter(name__startswith='function').count() == 0 + + class TestSharing: + def test_sharing_some_items( + self, db, session_item, module_item, class_item, + function_item): + assert _session_item.name == session_item.name + assert _module_item.name == module_item.name + assert _class_item.name != class_item.name + assert Item.objects.filter(name__startswith='function').count() == 1 + ''') + result = django_testdir.runpytest_subprocess('-v', '-s', '--reuse-db') + assert result.ret == 0 + + class TestDatabaseFixtures: """Tests for the db and transactional_db fixtures"""