diff --git a/rest_framework/views.py b/rest_framework/views.py index 69db053d6..cadcd2a64 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework. """ from django.conf import settings from django.core.exceptions import PermissionDenied -from django.db import connection, models, transaction +from django.db import connections, models, transaction from django.http import Http404 from django.http.response import HttpResponseBase from django.utils.cache import cc_delim_re, patch_vary_headers @@ -62,10 +62,19 @@ def get_view_description(view, html=False): return description -def set_rollback(): - atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) - if atomic_requests and connection.in_atomic_block: - transaction.set_rollback(True) +def set_rollback(request): + # We need the actual view func returned by the URL resolver which gets used + # by Django's BaseHandler to determine `non_atomic_requests`. Be cautious + # when fetching it though as it won't be set when views are tested with + # requessts from a RequestFactory. + try: + non_atomic_requests = request.resolver_match.func._non_atomic_requests + except AttributeError: + non_atomic_requests = set() + + for db in connections.all(): + if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests: + transaction.set_rollback(True, using=db.alias) def exception_handler(exc, context): @@ -95,7 +104,7 @@ def exception_handler(exc, context): else: data = {'detail': exc.detail} - set_rollback() + set_rollback(context['request']) return Response(data, status=exc.status_code, headers=headers) return None @@ -223,9 +232,9 @@ class APIView(View): """ return { 'view': self, - 'args': getattr(self, 'args', ()), - 'kwargs': getattr(self, 'kwargs', {}), - 'request': getattr(self, 'request', None) + 'args': self.args, + 'kwargs': self.kwargs, + 'request': self.request, } def get_view_name(self): diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index de04d2c06..c158169db 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -3,7 +3,7 @@ import unittest from django.conf.urls import url from django.db import connection, connections, transaction from django.http import Http404 -from django.test import TestCase, TransactionTestCase, override_settings +from django.test import TestCase, override_settings from rest_framework import status from rest_framework.exceptions import APIException @@ -39,12 +39,24 @@ class NonAtomicAPIExceptionView(APIView): return super().dispatch(*args, **kwargs) def get(self, request, *args, **kwargs): - BasicModel.objects.all() + list(BasicModel.objects.all()) + raise Http404 + + +class UrlDecoratedNonAtomicAPIExceptionView(APIView): + def get(self, request, *args, **kwargs): + list(BasicModel.objects.all()) raise Http404 urlpatterns = ( - url(r'^$', NonAtomicAPIExceptionView.as_view()), + url(r'^non-atomic-exception$', NonAtomicAPIExceptionView.as_view()), + url( + r'^url-decorated-non-atomic-exception$', + transaction.non_atomic_requests( + UrlDecoratedNonAtomicAPIExceptionView.as_view() + ), + ), ) @@ -94,8 +106,8 @@ class DBTransactionErrorTests(TestCase): # 1 - begin savepoint # 2 - insert # 3 - release savepoint - with transaction.atomic(): - self.assertRaises(Exception, self.view, request) + with transaction.atomic(), self.assertRaises(Exception): + self.view(request) assert not transaction.get_rollback() assert BasicModel.objects.count() == 1 @@ -135,7 +147,7 @@ class DBTransactionAPIExceptionTests(TestCase): "'atomic' requires transactions and savepoints." ) @override_settings(ROOT_URLCONF='tests.test_atomic_requests') -class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): +class NonAtomicDBTransactionAPIExceptionTests(TestCase): def setUp(self): connections.databases['default']['ATOMIC_REQUESTS'] = True @@ -143,8 +155,17 @@ class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): connections.databases['default']['ATOMIC_REQUESTS'] = False def test_api_exception_rollback_transaction_non_atomic_view(self): - response = self.client.get('/') + response = self.client.get('/non-atomic-exception') - # without checking connection.in_atomic_block view raises 500 - # due attempt to rollback without transaction assert response.status_code == status.HTTP_404_NOT_FOUND + assert not transaction.get_rollback() + # Check we can still perform DB queries + list(BasicModel.objects.all()) + + def test_api_exception_rollback_transaction_url_decorated_non_atomic_view(self): + response = self.client.get('/url-decorated-non-atomic-exception') + + assert response.status_code == status.HTTP_404_NOT_FOUND + assert not transaction.get_rollback() + # Check we can still perform DB queries + list(BasicModel.objects.all())