diff --git a/rest_framework/views.py b/rest_framework/views.py index d1b5e4ed9..5b0622069 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework. """ from django.conf import settings from django.core.exceptions import PermissionDenied -from django.db import connection, models, transaction +from django.db import connections, models from django.http import Http404 from django.http.response import HttpResponseBase from django.utils.cache import cc_delim_re, patch_vary_headers @@ -63,9 +63,9 @@ def get_view_description(view, html=False): def set_rollback(): - atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) - if atomic_requests and connection.in_atomic_block: - transaction.set_rollback(True) + for db in connections.all(): + if db.settings_dict['ATOMIC_REQUESTS'] and db.in_atomic_block: + db.set_rollback(True) def exception_handler(exc, context): diff --git a/tests/conftest.py b/tests/conftest.py index ac29e4a42..cc32cc637 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,10 @@ def pytest_configure(config): 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': ':memory:' + }, + 'secondary': { + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': ':memory:' } }, SITE_ID=1, diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 15b41e02f..beda5cba1 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -130,6 +130,41 @@ class DBTransactionAPIExceptionTests(TestCase): assert BasicModel.objects.count() == 0 +@unittest.skipUnless( + connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints." +) +class MultiDBTransactionAPIExceptionTests(TestCase): + databases = '__all__' + + def setUp(self): + self.view = APIExceptionView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + connections.databases['secondary']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + connections.databases['secondary']['ATOMIC_REQUESTS'] = False + + def test_api_exception_rollback_transaction(self): + """ + Transaction is rollbacked by our transaction atomic block. + """ + request = factory.post('/') + num_queries = 4 if connection.features.can_release_savepoints else 3 + with self.assertNumQueries(num_queries): + # 1 - begin savepoint + # 2 - insert + # 3 - rollback savepoint + # 4 - release savepoint + with transaction.atomic(), transaction.atomic(using='secondary'): + response = self.view(request) + assert transaction.get_rollback() + assert transaction.get_rollback(using='secondary') + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert BasicModel.objects.count() == 0 + + @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints."