Implementing signals for CRUD operations in the generic endpoint

implementations, as well as updating the existing tests of the generics
in order to test the signals.
This commit is contained in:
Brian Balsamo 2018-08-27 00:55:18 -05:00
parent 6522d4ae20
commit f1442b2c24
3 changed files with 215 additions and 17 deletions

View File

@ -10,6 +10,10 @@ from django.shortcuts import get_object_or_404 as _get_object_or_404
from rest_framework import mixins, views from rest_framework import mixins, views
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.signals import (
post_create, post_destroy, post_read, post_update, pre_create, pre_destroy,
pre_read, pre_update
)
def get_object_or_404(queryset, *filter_args, **filter_kwargs): def get_object_or_404(queryset, *filter_args, **filter_kwargs):
@ -189,7 +193,10 @@ class CreateAPIView(mixins.CreateModelMixin,
Concrete view for creating a model instance. Concrete view for creating a model instance.
""" """
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs) pre_create.send(self, request=request)
resp = self.create(request, *args, **kwargs)
post_create.send(self, request=request, response=resp)
return resp
class ListAPIView(mixins.ListModelMixin, class ListAPIView(mixins.ListModelMixin,
@ -198,7 +205,10 @@ class ListAPIView(mixins.ListModelMixin,
Concrete view for listing a queryset. Concrete view for listing a queryset.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs) pre_read.send(self, request=request)
resp = self.list(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
class RetrieveAPIView(mixins.RetrieveModelMixin, class RetrieveAPIView(mixins.RetrieveModelMixin,
@ -207,7 +217,10 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving a model instance. Concrete view for retrieving a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs) pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
class DestroyAPIView(mixins.DestroyModelMixin, class DestroyAPIView(mixins.DestroyModelMixin,
@ -216,7 +229,10 @@ class DestroyAPIView(mixins.DestroyModelMixin,
Concrete view for deleting a model instance. Concrete view for deleting a model instance.
""" """
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs) pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp
class UpdateAPIView(mixins.UpdateModelMixin, class UpdateAPIView(mixins.UpdateModelMixin,
@ -225,10 +241,16 @@ class UpdateAPIView(mixins.UpdateModelMixin,
Concrete view for updating a model instance. Concrete view for updating a model instance.
""" """
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs) pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs): def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs) pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
class ListCreateAPIView(mixins.ListModelMixin, class ListCreateAPIView(mixins.ListModelMixin,
@ -238,10 +260,16 @@ class ListCreateAPIView(mixins.ListModelMixin,
Concrete view for listing a queryset or creating a model instance. Concrete view for listing a queryset or creating a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs) pre_read.send(self, request=request)
resp = self.list(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs) pre_create.send(self, request=request)
resp = self.create(request, *args, **kwargs)
post_create.send(self, request=request, response=resp)
return resp
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
@ -251,13 +279,22 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving, updating a model instance. Concrete view for retrieving, updating a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs) pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs) pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs): def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs) pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
@ -267,10 +304,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving or deleting a model instance. Concrete view for retrieving or deleting a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs) pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs) pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
@ -281,13 +324,25 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving, updating or deleting a model instance. Concrete view for retrieving, updating or deleting a model instance.
""" """
def get(self, request, *args, **kwargs): def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs) pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def put(self, request, *args, **kwargs): def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs) pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs): def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs) pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def delete(self, request, *args, **kwargs): def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs) pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp

10
rest_framework/signals.py Normal file
View File

@ -0,0 +1,10 @@
from django.dispatch import Signal
pre_create = Signal(providing_args=['request'])
post_create = Signal(providing_args=['request', 'response'])
pre_read = Signal(providing_args=['request'])
post_read = Signal(providing_args=['request', 'response'])
pre_update = Signal(providing_args=['request'])
post_update = Signal(providing_args=['request', 'response'])
pre_destroy = Signal(providing_args=['request'])
post_destroy = Signal(providing_args=['request', 'response'])

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from uuid import uuid4
import pytest import pytest
from django.db import models from django.db import models
from django.http import Http404 from django.http import Http404
@ -9,6 +11,10 @@ from django.utils import six
from rest_framework import generics, renderers, serializers, status from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.signals import (
post_create, post_destroy, post_read, post_update, pre_create, pre_destroy,
pre_read, pre_update
)
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from tests.models import ( from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel, BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel,
@ -17,6 +23,83 @@ from tests.models import (
factory = APIRequestFactory() factory = APIRequestFactory()
# Signals dict to catch signals emitted per test
signals = {
"pre_create": 0,
"post_create": 0,
"pre_read": 0,
"post_read": 0,
"pre_update": 0,
"post_update": 0,
"pre_destroy": 0,
"post_destroy": 0
}
# Signal handlers
def pre_create_handler(sender, **kwargs):
signals['pre_create'] += 1
def post_create_handler(sender, **kwargs):
signals['post_create'] += 1
def pre_read_handler(sender, **kwargs):
signals['pre_read'] += 1
def post_read_handler(sender, **kwargs):
signals['post_read'] += 1
def pre_update_handler(sender, **kwargs):
signals['pre_update'] += 1
def post_update_handler(sender, **kwargs):
signals['post_update'] += 1
def pre_destroy_handler(sender, **kwargs):
signals['pre_destroy'] += 1
def post_destroy_handler(sender, **kwargs):
signals['post_destroy'] += 1
# Connect signal handlers to the signals
pre_create.connect(pre_create_handler, dispatch_uid=uuid4().hex)
post_create.connect(post_create_handler, dispatch_uid=uuid4().hex)
pre_read.connect(pre_read_handler, dispatch_uid=uuid4().hex)
post_read.connect(post_read_handler, dispatch_uid=uuid4().hex)
pre_update.connect(pre_update_handler, dispatch_uid=uuid4().hex)
post_update.connect(post_update_handler, dispatch_uid=uuid4().hex)
pre_destroy.connect(pre_destroy_handler, dispatch_uid=uuid4().hex)
post_destroy.connect(post_destroy_handler, dispatch_uid=uuid4().hex)
# Resets all entries in the signal dict to 0 in test tearDowns
def reset_all_signals():
for x in signals.keys():
signals[x] = 0
# Test each signal has been called the correct number of times
def _test_all_signals(pre_create=0, post_create=0,
pre_read=0, post_read=0,
pre_update=0, post_update=0,
pre_destroy=0, post_destroy=0):
assert signals['pre_create'] == pre_create
assert signals['post_create'] == post_create
assert signals['pre_read'] == pre_read
assert signals['post_read'] == post_read
assert signals['pre_update'] == pre_update
assert signals['post_update'] == post_update
assert signals['pre_destroy'] == pre_destroy
assert signals['post_destroy'] == post_destroy
# Models # Models
class SlugBasedModel(RESTFrameworkModel): class SlugBasedModel(RESTFrameworkModel):
@ -92,6 +175,7 @@ class TestRootView(TestCase):
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = RootView.as_view() self.view = RootView.as_view()
reset_all_signals()
def test_get_root_view(self): def test_get_root_view(self):
""" """
@ -100,6 +184,7 @@ class TestRootView(TestCase):
request = factory.get('/') request = factory.get('/')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data assert response.data == self.data
@ -110,6 +195,10 @@ class TestRootView(TestCase):
request = factory.head('/') request = factory.head('/')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
# See here:
# https://github.com/django/django/blob/master/django/views/generic/base.py#L63
# For why this triggers get triggers
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
def test_post_root_view(self): def test_post_root_view(self):
@ -120,6 +209,7 @@ class TestRootView(TestCase):
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
_test_all_signals(pre_create=1, post_create=1)
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'} assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4) created = self.objects.get(id=4)
@ -133,6 +223,7 @@ class TestRootView(TestCase):
request = factory.put('/', data, format='json') request = factory.put('/', data, format='json')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "PUT" not allowed.'} assert response.data == {"detail": 'Method "PUT" not allowed.'}
@ -143,6 +234,7 @@ class TestRootView(TestCase):
request = factory.delete('/') request = factory.delete('/')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "DELETE" not allowed.'} assert response.data == {"detail": 'Method "DELETE" not allowed.'}
@ -154,6 +246,7 @@ class TestRootView(TestCase):
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request).render() response = self.view(request).render()
_test_all_signals(pre_create=1, post_create=1)
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'} assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4) created = self.objects.get(id=4)
@ -168,6 +261,7 @@ class TestRootView(TestCase):
response = self.view(request).render() response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>' expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8') assert expected_error in response.rendered_content.decode('utf-8')
_test_all_signals(pre_create=1)
EXPECTED_QUERIES_FOR_PUT = 2 EXPECTED_QUERIES_FOR_PUT = 2
@ -188,6 +282,7 @@ class TestInstanceView(TestCase):
] ]
self.view = InstanceView.as_view() self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view() self.slug_based_view = SlugBasedInstanceView.as_view()
reset_all_signals()
def test_get_instance_view(self): def test_get_instance_view(self):
""" """
@ -196,6 +291,7 @@ class TestInstanceView(TestCase):
request = factory.get('/1') request = factory.get('/1')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0] assert response.data == self.data[0]
@ -207,6 +303,7 @@ class TestInstanceView(TestCase):
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "POST" not allowed.'} assert response.data == {"detail": 'Method "POST" not allowed.'}
@ -218,6 +315,7 @@ class TestInstanceView(TestCase):
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render() response = self.view(request, pk='1').render()
_test_all_signals(pre_update=1, post_update=1)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'} assert dict(response.data) == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1) updated = self.objects.get(id=1)
@ -232,6 +330,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
_test_all_signals(pre_update=1, post_update=1)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'} assert response.data == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1) updated = self.objects.get(id=1)
@ -244,6 +343,7 @@ class TestInstanceView(TestCase):
request = factory.delete('/1') request = factory.delete('/1')
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
_test_all_signals(pre_destroy=1, post_destroy=1)
assert response.status_code == status.HTTP_204_NO_CONTENT assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == six.b('') assert response.content == six.b('')
ids = [obj.id for obj in self.objects.all()] ids = [obj.id for obj in self.objects.all()]
@ -257,6 +357,7 @@ class TestInstanceView(TestCase):
request = factory.get('/a') request = factory.get('/a')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request, pk='a').render() response = self.view(request, pk='a').render()
_test_all_signals(pre_read=1)
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self): def test_put_cannot_set_id(self):
@ -282,6 +383,7 @@ class TestInstanceView(TestCase):
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
_test_all_signals(pre_update=1)
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_to_filtered_out_instance(self): def test_put_to_filtered_out_instance(self):
@ -294,6 +396,7 @@ class TestInstanceView(TestCase):
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render() response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
_test_all_signals(pre_update=1)
def test_patch_cannot_create_an_object(self): def test_patch_cannot_create_an_object(self):
""" """
@ -305,6 +408,7 @@ class TestInstanceView(TestCase):
response = self.view(request, pk=999).render() response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert not self.objects.filter(id=999).exists() assert not self.objects.filter(id=999).exists()
_test_all_signals(pre_update=1)
def test_put_error_instance_view(self): def test_put_error_instance_view(self):
""" """
@ -315,6 +419,7 @@ class TestInstanceView(TestCase):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>' expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8') assert expected_error in response.rendered_content.decode('utf-8')
_test_all_signals(pre_update=1)
class TestFKInstanceView(TestCase): class TestFKInstanceView(TestCase):
@ -366,6 +471,7 @@ class TestOverriddenGetObject(TestCase):
return get_object_or_404(BasicModel.objects.all(), id=pk) return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view() self.view = OverriddenGetObjectView.as_view()
reset_all_signals()
def test_overridden_get_object_view(self): def test_overridden_get_object_view(self):
""" """
@ -374,6 +480,7 @@ class TestOverriddenGetObject(TestCase):
request = factory.get('/1') request = factory.get('/1')
with self.assertNumQueries(1): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0] assert response.data == self.data[0]
@ -395,6 +502,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
def setUp(self): def setUp(self):
self.objects = Comment.objects self.objects = Comment.objects
self.view = CommentView.as_view() self.view = CommentView.as_view()
reset_all_signals()
def test_create_model_with_auto_now_add_field(self): def test_create_model_with_auto_now_add_field(self):
""" """
@ -408,6 +516,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
assert response.status_code == status.HTTP_201_CREATED assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1) created = self.objects.get(id=1)
assert created.content == 'foobar' assert created.content == 'foobar'
_test_all_signals(pre_create=1, post_create=1)
# Test for particularly ugly regression with m2m in browsable API # Test for particularly ugly regression with m2m in browsable API
@ -436,6 +545,9 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowsableAPI(TestCase): class TestM2MBrowsableAPI(TestCase):
def setUp(self):
reset_all_signals()
def test_m2m_in_browsable_api(self): def test_m2m_in_browsable_api(self):
""" """
Test for particularly ugly regression with m2m in browsable API Test for particularly ugly regression with m2m in browsable API
@ -444,6 +556,7 @@ class TestM2MBrowsableAPI(TestCase):
view = ExampleView().as_view() view = ExampleView().as_view()
response = view(request).render() response = view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
_test_all_signals(pre_read=1, post_read=1)
class InclusiveFilterBackend(object): class InclusiveFilterBackend(object):
@ -492,6 +605,7 @@ class TestFilterBackendAppliedToViews(TestCase):
{'id': obj.id, 'text': obj.text} {'id': obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
reset_all_signals()
def test_get_root_view_filters_by_name_with_filter_backend(self): def test_get_root_view_filters_by_name_with_filter_backend(self):
""" """
@ -503,6 +617,7 @@ class TestFilterBackendAppliedToViews(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1 assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}] assert response.data == [{'id': 1, 'text': 'foo'}]
_test_all_signals(pre_read=1, post_read=1)
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
""" """
@ -513,6 +628,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = root_view(request).render() response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == [] assert response.data == []
_test_all_signals(pre_read=1, post_read=1)
def test_get_instance_view_filters_out_name_with_filter_backend(self): def test_get_instance_view_filters_out_name_with_filter_backend(self):
""" """
@ -523,6 +639,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = instance_view(request, pk=1).render() response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'} assert response.data == {'detail': 'Not found.'}
_test_all_signals(pre_read=1)
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
""" """
@ -533,6 +650,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = instance_view(request, pk=1).render() response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'} assert response.data == {'id': 1, 'text': 'foo'}
_test_all_signals(pre_read=1, post_read=1)
def test_dynamic_serializer_form_in_browsable_api(self): def test_dynamic_serializer_form_in_browsable_api(self):
""" """
@ -544,9 +662,13 @@ class TestFilterBackendAppliedToViews(TestCase):
content = response.content.decode('utf8') content = response.content.decode('utf8')
assert 'field_b' in content assert 'field_b' in content
assert 'field_a' not in content assert 'field_a' not in content
_test_all_signals(pre_read=1, post_read=1)
class TestGuardedQueryset(TestCase): class TestGuardedQueryset(TestCase):
def setUp(self):
reset_all_signals()
def test_guarded_queryset(self): def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView): class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all() queryset = BasicModel.objects.all()
@ -556,11 +678,14 @@ class TestGuardedQueryset(TestCase):
view = QuerysetAccessError.as_view() view = QuerysetAccessError.as_view()
request = factory.get('/') request = factory.get('/')
_test_all_signals()
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
view(request).render() view(request).render()
class ApiViewsTests(TestCase): class ApiViewsTests(TestCase):
def setUp(self):
reset_all_signals()
def test_create_api_view_post(self): def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView): class MockCreateApiView(generics.CreateAPIView):
@ -572,6 +697,7 @@ class ApiViewsTests(TestCase):
view.post('test request', 'test arg', test_kwarg='test') view.post('test request', 'test arg', test_kwarg='test')
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
_test_all_signals(pre_create=1, post_create=1)
def test_destroy_api_view_delete(self): def test_destroy_api_view_delete(self):
class MockDestroyApiView(generics.DestroyAPIView): class MockDestroyApiView(generics.DestroyAPIView):
@ -583,6 +709,7 @@ class ApiViewsTests(TestCase):
view.delete('test request', 'test arg', test_kwarg='test') view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
_test_all_signals(pre_destroy=1, post_destroy=1)
def test_update_api_view_partial_update(self): def test_update_api_view_partial_update(self):
class MockUpdateApiView(generics.UpdateAPIView): class MockUpdateApiView(generics.UpdateAPIView):
@ -594,6 +721,7 @@ class ApiViewsTests(TestCase):
view.patch('test request', 'test arg', test_kwarg='test') view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_update_api_view_get(self): def test_retrieve_update_api_view_get(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -605,6 +733,7 @@ class ApiViewsTests(TestCase):
view.get('test request', 'test arg', test_kwarg='test') view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
_test_all_signals(pre_read=1, post_read=1)
def test_retrieve_update_api_view_put(self): def test_retrieve_update_api_view_put(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -616,6 +745,7 @@ class ApiViewsTests(TestCase):
view.put('test request', 'test arg', test_kwarg='test') view.put('test request', 'test arg', test_kwarg='test')
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_update_api_view_patch(self): def test_retrieve_update_api_view_patch(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -627,6 +757,7 @@ class ApiViewsTests(TestCase):
view.patch('test request', 'test arg', test_kwarg='test') view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_destroy_api_view_get(self): def test_retrieve_destroy_api_view_get(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@ -638,6 +769,7 @@ class ApiViewsTests(TestCase):
view.get('test request', 'test arg', test_kwarg='test') view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
_test_all_signals(pre_read=1, post_read=1)
def test_retrieve_destroy_api_view_delete(self): def test_retrieve_destroy_api_view_delete(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@ -649,6 +781,7 @@ class ApiViewsTests(TestCase):
view.delete('test request', 'test arg', test_kwarg='test') view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True assert view.called is True
assert view.call_args == data assert view.call_args == data
_test_all_signals(pre_destroy=1, post_destroy=1)
class GetObjectOr404Tests(TestCase): class GetObjectOr404Tests(TestCase):