Implementing signals for CRUD operations in the generic endpoint

implementations, as well as updating the existing tests of the generics
in order to test the signals.
This commit is contained in:
Brian Balsamo 2018-08-27 00:55:18 -05:00
parent 6522d4ae20
commit f1442b2c24
3 changed files with 215 additions and 17 deletions

View File

@ -10,6 +10,10 @@ from django.shortcuts import get_object_or_404 as _get_object_or_404
from rest_framework import mixins, views
from rest_framework.settings import api_settings
from rest_framework.signals import (
post_create, post_destroy, post_read, post_update, pre_create, pre_destroy,
pre_read, pre_update
)
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
@ -189,7 +193,10 @@ class CreateAPIView(mixins.CreateModelMixin,
Concrete view for creating a model instance.
"""
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
pre_create.send(self, request=request)
resp = self.create(request, *args, **kwargs)
post_create.send(self, request=request, response=resp)
return resp
class ListAPIView(mixins.ListModelMixin,
@ -198,7 +205,10 @@ class ListAPIView(mixins.ListModelMixin,
Concrete view for listing a queryset.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.list(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
class RetrieveAPIView(mixins.RetrieveModelMixin,
@ -207,7 +217,10 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
class DestroyAPIView(mixins.DestroyModelMixin,
@ -216,7 +229,10 @@ class DestroyAPIView(mixins.DestroyModelMixin,
Concrete view for deleting a model instance.
"""
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp
class UpdateAPIView(mixins.UpdateModelMixin,
@ -225,10 +241,16 @@ class UpdateAPIView(mixins.UpdateModelMixin,
Concrete view for updating a model instance.
"""
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
class ListCreateAPIView(mixins.ListModelMixin,
@ -238,10 +260,16 @@ class ListCreateAPIView(mixins.ListModelMixin,
Concrete view for listing a queryset or creating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.list(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
pre_create.send(self, request=request)
resp = self.create(request, *args, **kwargs)
post_create.send(self, request=request, response=resp)
return resp
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
@ -251,13 +279,22 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving, updating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
@ -267,10 +304,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
@ -281,13 +324,25 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving, updating or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp

10
rest_framework/signals.py Normal file
View File

@ -0,0 +1,10 @@
from django.dispatch import Signal
pre_create = Signal(providing_args=['request'])
post_create = Signal(providing_args=['request', 'response'])
pre_read = Signal(providing_args=['request'])
post_read = Signal(providing_args=['request', 'response'])
pre_update = Signal(providing_args=['request'])
post_update = Signal(providing_args=['request', 'response'])
pre_destroy = Signal(providing_args=['request'])
post_destroy = Signal(providing_args=['request', 'response'])

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals
from uuid import uuid4
import pytest
from django.db import models
from django.http import Http404
@ -9,6 +11,10 @@ from django.utils import six
from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response
from rest_framework.signals import (
post_create, post_destroy, post_read, post_update, pre_create, pre_destroy,
pre_read, pre_update
)
from rest_framework.test import APIRequestFactory
from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel,
@ -17,6 +23,83 @@ from tests.models import (
factory = APIRequestFactory()
# Signals dict to catch signals emitted per test
signals = {
"pre_create": 0,
"post_create": 0,
"pre_read": 0,
"post_read": 0,
"pre_update": 0,
"post_update": 0,
"pre_destroy": 0,
"post_destroy": 0
}
# Signal handlers
def pre_create_handler(sender, **kwargs):
signals['pre_create'] += 1
def post_create_handler(sender, **kwargs):
signals['post_create'] += 1
def pre_read_handler(sender, **kwargs):
signals['pre_read'] += 1
def post_read_handler(sender, **kwargs):
signals['post_read'] += 1
def pre_update_handler(sender, **kwargs):
signals['pre_update'] += 1
def post_update_handler(sender, **kwargs):
signals['post_update'] += 1
def pre_destroy_handler(sender, **kwargs):
signals['pre_destroy'] += 1
def post_destroy_handler(sender, **kwargs):
signals['post_destroy'] += 1
# Connect signal handlers to the signals
pre_create.connect(pre_create_handler, dispatch_uid=uuid4().hex)
post_create.connect(post_create_handler, dispatch_uid=uuid4().hex)
pre_read.connect(pre_read_handler, dispatch_uid=uuid4().hex)
post_read.connect(post_read_handler, dispatch_uid=uuid4().hex)
pre_update.connect(pre_update_handler, dispatch_uid=uuid4().hex)
post_update.connect(post_update_handler, dispatch_uid=uuid4().hex)
pre_destroy.connect(pre_destroy_handler, dispatch_uid=uuid4().hex)
post_destroy.connect(post_destroy_handler, dispatch_uid=uuid4().hex)
# Resets all entries in the signal dict to 0 in test tearDowns
def reset_all_signals():
for x in signals.keys():
signals[x] = 0
# Test each signal has been called the correct number of times
def _test_all_signals(pre_create=0, post_create=0,
pre_read=0, post_read=0,
pre_update=0, post_update=0,
pre_destroy=0, post_destroy=0):
assert signals['pre_create'] == pre_create
assert signals['post_create'] == post_create
assert signals['pre_read'] == pre_read
assert signals['post_read'] == post_read
assert signals['pre_update'] == pre_update
assert signals['post_update'] == post_update
assert signals['pre_destroy'] == pre_destroy
assert signals['post_destroy'] == post_destroy
# Models
class SlugBasedModel(RESTFrameworkModel):
@ -92,6 +175,7 @@ class TestRootView(TestCase):
for obj in self.objects.all()
]
self.view = RootView.as_view()
reset_all_signals()
def test_get_root_view(self):
"""
@ -100,6 +184,7 @@ class TestRootView(TestCase):
request = factory.get('/')
with self.assertNumQueries(1):
response = self.view(request).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data
@ -110,6 +195,10 @@ class TestRootView(TestCase):
request = factory.head('/')
with self.assertNumQueries(1):
response = self.view(request).render()
# See here:
# https://github.com/django/django/blob/master/django/views/generic/base.py#L63
# For why this triggers get triggers
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
def test_post_root_view(self):
@ -120,6 +209,7 @@ class TestRootView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
_test_all_signals(pre_create=1, post_create=1)
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4)
@ -133,6 +223,7 @@ class TestRootView(TestCase):
request = factory.put('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "PUT" not allowed.'}
@ -143,6 +234,7 @@ class TestRootView(TestCase):
request = factory.delete('/')
with self.assertNumQueries(0):
response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "DELETE" not allowed.'}
@ -154,6 +246,7 @@ class TestRootView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
_test_all_signals(pre_create=1, post_create=1)
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4)
@ -168,6 +261,7 @@ class TestRootView(TestCase):
response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8')
_test_all_signals(pre_create=1)
EXPECTED_QUERIES_FOR_PUT = 2
@ -188,6 +282,7 @@ class TestInstanceView(TestCase):
]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
reset_all_signals()
def test_get_instance_view(self):
"""
@ -196,6 +291,7 @@ class TestInstanceView(TestCase):
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0]
@ -207,6 +303,7 @@ class TestInstanceView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "POST" not allowed.'}
@ -218,6 +315,7 @@ class TestInstanceView(TestCase):
request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render()
_test_all_signals(pre_update=1, post_update=1)
assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
@ -232,6 +330,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
_test_all_signals(pre_update=1, post_update=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
@ -244,6 +343,7 @@ class TestInstanceView(TestCase):
request = factory.delete('/1')
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
_test_all_signals(pre_destroy=1, post_destroy=1)
assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == six.b('')
ids = [obj.id for obj in self.objects.all()]
@ -257,6 +357,7 @@ class TestInstanceView(TestCase):
request = factory.get('/a')
with self.assertNumQueries(0):
response = self.view(request, pk='a').render()
_test_all_signals(pre_read=1)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self):
@ -282,6 +383,7 @@ class TestInstanceView(TestCase):
request = factory.put('/1', data, format='json')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
_test_all_signals(pre_update=1)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_to_filtered_out_instance(self):
@ -294,6 +396,7 @@ class TestInstanceView(TestCase):
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
_test_all_signals(pre_update=1)
def test_patch_cannot_create_an_object(self):
"""
@ -305,6 +408,7 @@ class TestInstanceView(TestCase):
response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert not self.objects.filter(id=999).exists()
_test_all_signals(pre_update=1)
def test_put_error_instance_view(self):
"""
@ -315,6 +419,7 @@ class TestInstanceView(TestCase):
response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8')
_test_all_signals(pre_update=1)
class TestFKInstanceView(TestCase):
@ -366,6 +471,7 @@ class TestOverriddenGetObject(TestCase):
return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view()
reset_all_signals()
def test_overridden_get_object_view(self):
"""
@ -374,6 +480,7 @@ class TestOverriddenGetObject(TestCase):
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0]
@ -395,6 +502,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
def setUp(self):
self.objects = Comment.objects
self.view = CommentView.as_view()
reset_all_signals()
def test_create_model_with_auto_now_add_field(self):
"""
@ -408,6 +516,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1)
assert created.content == 'foobar'
_test_all_signals(pre_create=1, post_create=1)
# Test for particularly ugly regression with m2m in browsable API
@ -436,6 +545,9 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowsableAPI(TestCase):
def setUp(self):
reset_all_signals()
def test_m2m_in_browsable_api(self):
"""
Test for particularly ugly regression with m2m in browsable API
@ -444,6 +556,7 @@ class TestM2MBrowsableAPI(TestCase):
view = ExampleView().as_view()
response = view(request).render()
assert response.status_code == status.HTTP_200_OK
_test_all_signals(pre_read=1, post_read=1)
class InclusiveFilterBackend(object):
@ -492,6 +605,7 @@ class TestFilterBackendAppliedToViews(TestCase):
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
reset_all_signals()
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
@ -503,6 +617,7 @@ class TestFilterBackendAppliedToViews(TestCase):
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}]
_test_all_signals(pre_read=1, post_read=1)
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
"""
@ -513,6 +628,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == []
_test_all_signals(pre_read=1, post_read=1)
def test_get_instance_view_filters_out_name_with_filter_backend(self):
"""
@ -523,6 +639,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'}
_test_all_signals(pre_read=1)
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
"""
@ -533,6 +650,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'}
_test_all_signals(pre_read=1, post_read=1)
def test_dynamic_serializer_form_in_browsable_api(self):
"""
@ -544,9 +662,13 @@ class TestFilterBackendAppliedToViews(TestCase):
content = response.content.decode('utf8')
assert 'field_b' in content
assert 'field_a' not in content
_test_all_signals(pre_read=1, post_read=1)
class TestGuardedQueryset(TestCase):
def setUp(self):
reset_all_signals()
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all()
@ -556,11 +678,14 @@ class TestGuardedQueryset(TestCase):
view = QuerysetAccessError.as_view()
request = factory.get('/')
_test_all_signals()
with pytest.raises(RuntimeError):
view(request).render()
class ApiViewsTests(TestCase):
def setUp(self):
reset_all_signals()
def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView):
@ -572,6 +697,7 @@ class ApiViewsTests(TestCase):
view.post('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_create=1, post_create=1)
def test_destroy_api_view_delete(self):
class MockDestroyApiView(generics.DestroyAPIView):
@ -583,6 +709,7 @@ class ApiViewsTests(TestCase):
view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_destroy=1, post_destroy=1)
def test_update_api_view_partial_update(self):
class MockUpdateApiView(generics.UpdateAPIView):
@ -594,6 +721,7 @@ class ApiViewsTests(TestCase):
view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_update_api_view_get(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -605,6 +733,7 @@ class ApiViewsTests(TestCase):
view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_read=1, post_read=1)
def test_retrieve_update_api_view_put(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -616,6 +745,7 @@ class ApiViewsTests(TestCase):
view.put('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_update_api_view_patch(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -627,6 +757,7 @@ class ApiViewsTests(TestCase):
view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_destroy_api_view_get(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@ -638,6 +769,7 @@ class ApiViewsTests(TestCase):
view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_read=1, post_read=1)
def test_retrieve_destroy_api_view_delete(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@ -649,6 +781,7 @@ class ApiViewsTests(TestCase):
view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_destroy=1, post_destroy=1)
class GetObjectOr404Tests(TestCase):