Improve set_rollback() behaviour

\## Description

Fixes #6921.

Added tests that fail before and pass afterwards.

Remove the check for `connection.in_atomic_block` to determine if the current request is under a `transaction.atomic` from `ATOMIC_REQUESTS`. Instead, duplicate the method that Django itself uses [in BaseHandler](964dd4f4f2/django/core/handlers/base.py (L64)).

This requires fetching the actual view function from `as_view()`, as seen by the URL resolver / BaseHandler. Since this requires `request`, I've also changed the accesses in `get_exception_handler_context` to be direct attribute accesses rather than `getattr()`. It seems the `getattr` defaults not accessible since `self.request`, `self.args`, and `self.kwargs` are always set in `dispatch()` before `handle_exception()` can ever be called. This is useful since `request` is always needed for the new `set_rollback` logic.
This commit is contained in:
Adam Johnson 2019-09-10 11:21:20 +01:00
parent de497a9bf1
commit 5e5f559b4d
2 changed files with 48 additions and 18 deletions

View File

@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.
"""
from django.conf import settings
from django.core.exceptions import PermissionDenied
from django.db import connection, models, transaction
from django.db import connections, models, transaction
from django.http import Http404
from django.http.response import HttpResponseBase
from django.utils.cache import cc_delim_re, patch_vary_headers
@ -62,10 +62,19 @@ def get_view_description(view, html=False):
return description
def set_rollback():
atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False)
if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True)
def set_rollback(request):
# We need the actual view func returned by the URL resolver which gets used
# by Django's BaseHandler to determine `non_atomic_requests`. Be cautious
# when fetching it though as it won't be set when views are tested with
# requessts from a RequestFactory.
try:
non_atomic_requests = request.resolver_match.func._non_atomic_requests
except AttributeError:
non_atomic_requests = set()
for db in connections.all():
if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests:
transaction.set_rollback(True, using=db.alias)
def exception_handler(exc, context):
@ -95,7 +104,7 @@ def exception_handler(exc, context):
else:
data = {'detail': exc.detail}
set_rollback()
set_rollback(context['request'])
return Response(data, status=exc.status_code, headers=headers)
return None
@ -223,9 +232,9 @@ class APIView(View):
"""
return {
'view': self,
'args': getattr(self, 'args', ()),
'kwargs': getattr(self, 'kwargs', {}),
'request': getattr(self, 'request', None)
'args': self.args,
'kwargs': self.kwargs,
'request': self.request,
}
def get_view_name(self):

View File

@ -3,7 +3,7 @@ import unittest
from django.conf.urls import url
from django.db import connection, connections, transaction
from django.http import Http404
from django.test import TestCase, TransactionTestCase, override_settings
from django.test import TestCase, override_settings
from rest_framework import status
from rest_framework.exceptions import APIException
@ -39,12 +39,24 @@ class NonAtomicAPIExceptionView(APIView):
return super().dispatch(*args, **kwargs)
def get(self, request, *args, **kwargs):
BasicModel.objects.all()
list(BasicModel.objects.all())
raise Http404
class UrlDecoratedNonAtomicAPIExceptionView(APIView):
def get(self, request, *args, **kwargs):
list(BasicModel.objects.all())
raise Http404
urlpatterns = (
url(r'^$', NonAtomicAPIExceptionView.as_view()),
url(r'^non-atomic-exception$', NonAtomicAPIExceptionView.as_view()),
url(
r'^url-decorated-non-atomic-exception$',
transaction.non_atomic_requests(
UrlDecoratedNonAtomicAPIExceptionView.as_view()
),
),
)
@ -94,8 +106,8 @@ class DBTransactionErrorTests(TestCase):
# 1 - begin savepoint
# 2 - insert
# 3 - release savepoint
with transaction.atomic():
self.assertRaises(Exception, self.view, request)
with transaction.atomic(), self.assertRaises(Exception):
self.view(request)
assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1
@ -135,7 +147,7 @@ class DBTransactionAPIExceptionTests(TestCase):
"'atomic' requires transactions and savepoints."
)
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
class NonAtomicDBTransactionAPIExceptionTests(TestCase):
def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True
@ -143,8 +155,17 @@ class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
connections.databases['default']['ATOMIC_REQUESTS'] = False
def test_api_exception_rollback_transaction_non_atomic_view(self):
response = self.client.get('/')
response = self.client.get('/non-atomic-exception')
# without checking connection.in_atomic_block view raises 500
# due attempt to rollback without transaction
assert response.status_code == status.HTTP_404_NOT_FOUND
assert not transaction.get_rollback()
# Check we can still perform DB queries
list(BasicModel.objects.all())
def test_api_exception_rollback_transaction_url_decorated_non_atomic_view(self):
response = self.client.get('/url-decorated-non-atomic-exception')
assert response.status_code == status.HTTP_404_NOT_FOUND
assert not transaction.get_rollback()
# Check we can still perform DB queries
list(BasicModel.objects.all())