mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-09-17 09:42:29 +03:00
Support atomic transaction views in multiple database connections
This commit is contained in:
parent
0618fa88e1
commit
ac3e8495e7
|
@ -16,30 +16,35 @@ factory = APIRequestFactory()
|
||||||
|
|
||||||
|
|
||||||
class BasicView(APIView):
|
class BasicView(APIView):
|
||||||
|
database = 'default'
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
return BasicModel.objects.using(self.database).all()
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
BasicModel.objects.create()
|
self.get_queryset().create()
|
||||||
return Response({'method': 'GET'})
|
return Response({'method': 'GET'})
|
||||||
|
|
||||||
|
|
||||||
class ErrorView(APIView):
|
class ErrorView(BasicView):
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
BasicModel.objects.create()
|
self.get_queryset().create()
|
||||||
raise Exception
|
raise Exception
|
||||||
|
|
||||||
|
|
||||||
class APIExceptionView(APIView):
|
class APIExceptionView(BasicView):
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
BasicModel.objects.create()
|
self.get_queryset().create()
|
||||||
raise APIException
|
raise APIException
|
||||||
|
|
||||||
|
|
||||||
class NonAtomicAPIExceptionView(APIView):
|
class NonAtomicAPIExceptionView(BasicView):
|
||||||
@transaction.non_atomic_requests
|
@transaction.non_atomic_requests
|
||||||
def dispatch(self, *args, **kwargs):
|
def dispatch(self, *args, **kwargs):
|
||||||
return super().dispatch(*args, **kwargs)
|
return super().dispatch(*args, **kwargs)
|
||||||
|
|
||||||
def get(self, request, *args, **kwargs):
|
def get(self, request, *args, **kwargs):
|
||||||
BasicModel.objects.all()
|
self.get_queryset()
|
||||||
raise Http404
|
raise Http404
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,34 +58,52 @@ urlpatterns = (
|
||||||
"'atomic' requires transactions and savepoints."
|
"'atomic' requires transactions and savepoints."
|
||||||
)
|
)
|
||||||
class DBTransactionTests(TestCase):
|
class DBTransactionTests(TestCase):
|
||||||
|
databases = '__all__'
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.view = BasicView.as_view()
|
self.view = BasicView
|
||||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
for database in connections.databases:
|
||||||
|
connections.databases[database]['ATOMIC_REQUESTS'] = True
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
connections.databases['default']['ATOMIC_REQUESTS'] = False
|
for database in connections.databases:
|
||||||
|
connections.databases[database]['ATOMIC_REQUESTS'] = False
|
||||||
|
|
||||||
def test_no_exception_commit_transaction(self):
|
def test_no_exception_commit_transaction(self):
|
||||||
request = factory.post('/')
|
request = factory.post('/')
|
||||||
|
|
||||||
with self.assertNumQueries(1):
|
with self.assertNumQueries(1):
|
||||||
response = self.view(request)
|
response = self.view.as_view()(request)
|
||||||
assert not transaction.get_rollback()
|
assert not transaction.get_rollback()
|
||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
assert BasicModel.objects.count() == 1
|
assert BasicModel.objects.count() == 1
|
||||||
|
|
||||||
|
def test_no_exception_commit_transaction_spare_connection(self):
|
||||||
|
request = factory.post('/')
|
||||||
|
|
||||||
|
with self.assertNumQueries(1, using='spare'):
|
||||||
|
view = self.view.as_view(database='spare')
|
||||||
|
response = view(request)
|
||||||
|
assert not transaction.get_rollback(using='spare')
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert BasicModel.objects.using('spare').count() == 1
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
connection.features.uses_savepoints,
|
connection.features.uses_savepoints,
|
||||||
"'atomic' requires transactions and savepoints."
|
"'atomic' requires transactions and savepoints."
|
||||||
)
|
)
|
||||||
class DBTransactionErrorTests(TestCase):
|
class DBTransactionErrorTests(TestCase):
|
||||||
|
databases = '__all__'
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.view = ErrorView.as_view()
|
self.view = ErrorView
|
||||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
for database in connections.databases:
|
||||||
|
connections.databases[database]['ATOMIC_REQUESTS'] = True
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
connections.databases['default']['ATOMIC_REQUESTS'] = False
|
for database in connections.databases:
|
||||||
|
connections.databases[database]['ATOMIC_REQUESTS'] = False
|
||||||
|
|
||||||
def test_generic_exception_delegate_transaction_management(self):
|
def test_generic_exception_delegate_transaction_management(self):
|
||||||
"""
|
"""
|
||||||
|
@ -95,22 +118,37 @@ class DBTransactionErrorTests(TestCase):
|
||||||
# 2 - insert
|
# 2 - insert
|
||||||
# 3 - release savepoint
|
# 3 - release savepoint
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
self.assertRaises(Exception, self.view, request)
|
self.assertRaises(Exception, self.view.as_view(), request)
|
||||||
assert not transaction.get_rollback()
|
assert not transaction.get_rollback()
|
||||||
assert BasicModel.objects.count() == 1
|
assert BasicModel.objects.count() == 1
|
||||||
|
|
||||||
|
def test_generic_exception_delegate_transaction_management_spare_connections(self):
|
||||||
|
request = factory.post('/')
|
||||||
|
with self.assertNumQueries(3, using='spare'):
|
||||||
|
# 1 - begin savepoint
|
||||||
|
# 2 - insert
|
||||||
|
# 3 - release savepoint
|
||||||
|
with transaction.atomic(using='spare'):
|
||||||
|
self.assertRaises(Exception, self.view.as_view(database='spare'), request)
|
||||||
|
assert not transaction.get_rollback(using='spare')
|
||||||
|
assert BasicModel.objects.using('spare').count() == 1
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
connection.features.uses_savepoints,
|
connection.features.uses_savepoints,
|
||||||
"'atomic' requires transactions and savepoints."
|
"'atomic' requires transactions and savepoints."
|
||||||
)
|
)
|
||||||
class DBTransactionAPIExceptionTests(TestCase):
|
class DBTransactionAPIExceptionTests(TestCase):
|
||||||
|
databases = '__all__'
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.view = APIExceptionView.as_view()
|
self.view = APIExceptionView
|
||||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
for database in connections.databases:
|
||||||
|
connections.databases[database]['ATOMIC_REQUESTS'] = True
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
connections.databases['default']['ATOMIC_REQUESTS'] = False
|
for database in connections.databases:
|
||||||
|
connections.databases[database]['ATOMIC_REQUESTS'] = False
|
||||||
|
|
||||||
def test_api_exception_rollback_transaction(self):
|
def test_api_exception_rollback_transaction(self):
|
||||||
"""
|
"""
|
||||||
|
@ -124,11 +162,28 @@ class DBTransactionAPIExceptionTests(TestCase):
|
||||||
# 3 - rollback savepoint
|
# 3 - rollback savepoint
|
||||||
# 4 - release savepoint
|
# 4 - release savepoint
|
||||||
with transaction.atomic():
|
with transaction.atomic():
|
||||||
response = self.view(request)
|
response = self.view.as_view()(request)
|
||||||
assert transaction.get_rollback()
|
assert transaction.get_rollback()
|
||||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
assert BasicModel.objects.count() == 0
|
assert BasicModel.objects.count() == 0
|
||||||
|
|
||||||
|
def test_api_exception_rollback_transaction_spare_connection(self):
|
||||||
|
"""
|
||||||
|
Transaction is rollbacked by our transaction atomic block.
|
||||||
|
"""
|
||||||
|
request = factory.post('/')
|
||||||
|
num_queries = 4 if connections['spare'].features.can_release_savepoints else 3
|
||||||
|
with self.assertNumQueries(num_queries, using='spare'):
|
||||||
|
# 1 - begin savepoint
|
||||||
|
# 2 - insert
|
||||||
|
# 3 - rollback savepoint
|
||||||
|
# 4 - release savepoint
|
||||||
|
with transaction.atomic(using='spare'):
|
||||||
|
response = self.view.as_view(database='spare')(request)
|
||||||
|
assert transaction.get_rollback(using='spare')
|
||||||
|
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
assert BasicModel.objects.using('spare').count() == 0
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(
|
@unittest.skipUnless(
|
||||||
connection.features.uses_savepoints,
|
connection.features.uses_savepoints,
|
||||||
|
@ -171,11 +226,15 @@ class MultiDBTransactionAPIExceptionTests(TestCase):
|
||||||
)
|
)
|
||||||
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
|
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
|
||||||
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
|
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
|
||||||
|
databases = '__all__'
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
connections.databases['default']['ATOMIC_REQUESTS'] = True
|
for database in connections.databases:
|
||||||
|
connections.databases[database]['ATOMIC_REQUESTS'] = True
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
connections.databases['default']['ATOMIC_REQUESTS'] = False
|
for database in connections.databases:
|
||||||
|
connections.databases[database]['ATOMIC_REQUESTS'] = False
|
||||||
|
|
||||||
def test_api_exception_rollback_transaction_non_atomic_view(self):
|
def test_api_exception_rollback_transaction_non_atomic_view(self):
|
||||||
response = self.client.get('/')
|
response = self.client.get('/')
|
||||||
|
|
Loading…
Reference in New Issue
Block a user