From 5c65845b158b112ac44fe8d35b4aed230ddb741a Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Fri, 13 Dec 2019 20:51:54 +0000 Subject: [PATCH] Revert check for non_atomic_requests, instead rely again on db.in_atomic_block --- rest_framework/views.py | 19 +++++++------------ tests/test_atomic_requests.py | 32 +++++++------------------------- 2 files changed, 14 insertions(+), 37 deletions(-) diff --git a/rest_framework/views.py b/rest_framework/views.py index cadcd2a64..9120079d0 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -62,18 +62,13 @@ def get_view_description(view, html=False): return description -def set_rollback(request): - # We need the actual view func returned by the URL resolver which gets used - # by Django's BaseHandler to determine `non_atomic_requests`. Be cautious - # when fetching it though as it won't be set when views are tested with - # requessts from a RequestFactory. - try: - non_atomic_requests = request.resolver_match.func._non_atomic_requests - except AttributeError: - non_atomic_requests = set() - +def set_rollback(): + # Rollback all connections that have ATOMIC_REQUESTS set, if it looks their + # @atomic for the request was started + # Note this check in_atomic_block may be a false positive due to + # transactions started another way, e.g. through testing with TestCase for db in connections.all(): - if db.settings_dict['ATOMIC_REQUESTS'] and db.alias not in non_atomic_requests: + if db.settings_dict['ATOMIC_REQUESTS'] and db.in_atomic_block: transaction.set_rollback(True, using=db.alias) @@ -104,7 +99,7 @@ def exception_handler(exc, context): else: data = {'detail': exc.detail} - set_rollback(context['request']) + set_rollback() return Response(data, status=exc.status_code, headers=headers) return None diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index c158169db..e6969bdfa 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -3,7 +3,7 @@ import unittest from django.conf.urls import url from django.db import connection, connections, transaction from django.http import Http404 -from django.test import TestCase, override_settings +from django.test import TestCase, TransactionTestCase, override_settings from rest_framework import status from rest_framework.exceptions import APIException @@ -43,20 +43,8 @@ class NonAtomicAPIExceptionView(APIView): raise Http404 -class UrlDecoratedNonAtomicAPIExceptionView(APIView): - def get(self, request, *args, **kwargs): - list(BasicModel.objects.all()) - raise Http404 - - urlpatterns = ( url(r'^non-atomic-exception$', NonAtomicAPIExceptionView.as_view()), - url( - r'^url-decorated-non-atomic-exception$', - transaction.non_atomic_requests( - UrlDecoratedNonAtomicAPIExceptionView.as_view() - ), - ), ) @@ -147,25 +135,19 @@ class DBTransactionAPIExceptionTests(TestCase): "'atomic' requires transactions and savepoints." ) @override_settings(ROOT_URLCONF='tests.test_atomic_requests') -class NonAtomicDBTransactionAPIExceptionTests(TestCase): +class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): def setUp(self): connections.databases['default']['ATOMIC_REQUESTS'] = True - def tearDown(self): - connections.databases['default']['ATOMIC_REQUESTS'] = False + @self.addCleanup + def restore_atomic_requests(): + connections.databases['default']['ATOMIC_REQUESTS'] = False def test_api_exception_rollback_transaction_non_atomic_view(self): response = self.client.get('/non-atomic-exception') + # without check for db.in_atomic_block, would raise 500 due to attempt + # to rollback without transaction assert response.status_code == status.HTTP_404_NOT_FOUND - assert not transaction.get_rollback() - # Check we can still perform DB queries - list(BasicModel.objects.all()) - - def test_api_exception_rollback_transaction_url_decorated_non_atomic_view(self): - response = self.client.get('/url-decorated-non-atomic-exception') - - assert response.status_code == status.HTTP_404_NOT_FOUND - assert not transaction.get_rollback() # Check we can still perform DB queries list(BasicModel.objects.all())