diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index 8d0bf284a..56349d2b9 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -10,6 +10,10 @@ from django.shortcuts import get_object_or_404 as _get_object_or_404
from rest_framework import mixins, views
from rest_framework.settings import api_settings
+from rest_framework.signals import (
+ post_create, post_destroy, post_read, post_update, pre_create, pre_destroy,
+ pre_read, pre_update
+)
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
@@ -189,7 +193,10 @@ class CreateAPIView(mixins.CreateModelMixin,
Concrete view for creating a model instance.
"""
def post(self, request, *args, **kwargs):
- return self.create(request, *args, **kwargs)
+ pre_create.send(self, request=request)
+ resp = self.create(request, *args, **kwargs)
+ post_create.send(self, request=request, response=resp)
+ return resp
class ListAPIView(mixins.ListModelMixin,
@@ -198,7 +205,10 @@ class ListAPIView(mixins.ListModelMixin,
Concrete view for listing a queryset.
"""
def get(self, request, *args, **kwargs):
- return self.list(request, *args, **kwargs)
+ pre_read.send(self, request=request)
+ resp = self.list(request, *args, **kwargs)
+ post_read.send(self, request=request, response=resp)
+ return resp
class RetrieveAPIView(mixins.RetrieveModelMixin,
@@ -207,7 +217,10 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving a model instance.
"""
def get(self, request, *args, **kwargs):
- return self.retrieve(request, *args, **kwargs)
+ pre_read.send(self, request=request)
+ resp = self.retrieve(request, *args, **kwargs)
+ post_read.send(self, request=request, response=resp)
+ return resp
class DestroyAPIView(mixins.DestroyModelMixin,
@@ -216,7 +229,10 @@ class DestroyAPIView(mixins.DestroyModelMixin,
Concrete view for deleting a model instance.
"""
def delete(self, request, *args, **kwargs):
- return self.destroy(request, *args, **kwargs)
+ pre_destroy.send(self, request=request)
+ resp = self.destroy(request, *args, **kwargs)
+ post_destroy.send(self, request=request, response=resp)
+ return resp
class UpdateAPIView(mixins.UpdateModelMixin,
@@ -225,10 +241,16 @@ class UpdateAPIView(mixins.UpdateModelMixin,
Concrete view for updating a model instance.
"""
def put(self, request, *args, **kwargs):
- return self.update(request, *args, **kwargs)
+ pre_update.send(self, request=request)
+ resp = self.update(request, *args, **kwargs)
+ post_update.send(self, request=request, response=resp)
+ return resp
def patch(self, request, *args, **kwargs):
- return self.partial_update(request, *args, **kwargs)
+ pre_update.send(self, request=request)
+ resp = self.partial_update(request, *args, **kwargs)
+ post_update.send(self, request=request, response=resp)
+ return resp
class ListCreateAPIView(mixins.ListModelMixin,
@@ -238,10 +260,16 @@ class ListCreateAPIView(mixins.ListModelMixin,
Concrete view for listing a queryset or creating a model instance.
"""
def get(self, request, *args, **kwargs):
- return self.list(request, *args, **kwargs)
+ pre_read.send(self, request=request)
+ resp = self.list(request, *args, **kwargs)
+ post_read.send(self, request=request, response=resp)
+ return resp
def post(self, request, *args, **kwargs):
- return self.create(request, *args, **kwargs)
+ pre_create.send(self, request=request)
+ resp = self.create(request, *args, **kwargs)
+ post_create.send(self, request=request, response=resp)
+ return resp
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
@@ -251,13 +279,22 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving, updating a model instance.
"""
def get(self, request, *args, **kwargs):
- return self.retrieve(request, *args, **kwargs)
+ pre_read.send(self, request=request)
+ resp = self.retrieve(request, *args, **kwargs)
+ post_read.send(self, request=request, response=resp)
+ return resp
def put(self, request, *args, **kwargs):
- return self.update(request, *args, **kwargs)
+ pre_update.send(self, request=request)
+ resp = self.update(request, *args, **kwargs)
+ post_update.send(self, request=request, response=resp)
+ return resp
def patch(self, request, *args, **kwargs):
- return self.partial_update(request, *args, **kwargs)
+ pre_update.send(self, request=request)
+ resp = self.partial_update(request, *args, **kwargs)
+ post_update.send(self, request=request, response=resp)
+ return resp
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
@@ -267,10 +304,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
- return self.retrieve(request, *args, **kwargs)
+ pre_read.send(self, request=request)
+ resp = self.retrieve(request, *args, **kwargs)
+ post_read.send(self, request=request, response=resp)
+ return resp
def delete(self, request, *args, **kwargs):
- return self.destroy(request, *args, **kwargs)
+ pre_destroy.send(self, request=request)
+ resp = self.destroy(request, *args, **kwargs)
+ post_destroy.send(self, request=request, response=resp)
+ return resp
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
@@ -281,13 +324,25 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving, updating or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
- return self.retrieve(request, *args, **kwargs)
+ pre_read.send(self, request=request)
+ resp = self.retrieve(request, *args, **kwargs)
+ post_read.send(self, request=request, response=resp)
+ return resp
def put(self, request, *args, **kwargs):
- return self.update(request, *args, **kwargs)
+ pre_update.send(self, request=request)
+ resp = self.update(request, *args, **kwargs)
+ post_update.send(self, request=request, response=resp)
+ return resp
def patch(self, request, *args, **kwargs):
- return self.partial_update(request, *args, **kwargs)
+ pre_update.send(self, request=request)
+ resp = self.partial_update(request, *args, **kwargs)
+ post_update.send(self, request=request, response=resp)
+ return resp
def delete(self, request, *args, **kwargs):
- return self.destroy(request, *args, **kwargs)
+ pre_destroy.send(self, request=request)
+ resp = self.destroy(request, *args, **kwargs)
+ post_destroy.send(self, request=request, response=resp)
+ return resp
diff --git a/rest_framework/signals.py b/rest_framework/signals.py
new file mode 100644
index 000000000..9b55e8933
--- /dev/null
+++ b/rest_framework/signals.py
@@ -0,0 +1,10 @@
+from django.dispatch import Signal
+
+pre_create = Signal(providing_args=['request'])
+post_create = Signal(providing_args=['request', 'response'])
+pre_read = Signal(providing_args=['request'])
+post_read = Signal(providing_args=['request', 'response'])
+pre_update = Signal(providing_args=['request'])
+post_update = Signal(providing_args=['request', 'response'])
+pre_destroy = Signal(providing_args=['request'])
+post_destroy = Signal(providing_args=['request', 'response'])
diff --git a/tests/test_generics.py b/tests/test_generics.py
index c0ff1c5c4..6f2e1f093 100644
--- a/tests/test_generics.py
+++ b/tests/test_generics.py
@@ -1,5 +1,7 @@
from __future__ import unicode_literals
+from uuid import uuid4
+
import pytest
from django.db import models
from django.http import Http404
@@ -9,6 +11,10 @@ from django.utils import six
from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response
+from rest_framework.signals import (
+ post_create, post_destroy, post_read, post_update, pre_create, pre_destroy,
+ pre_read, pre_update
+)
from rest_framework.test import APIRequestFactory
from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel,
@@ -17,6 +23,83 @@ from tests.models import (
factory = APIRequestFactory()
+# Signals dict to catch signals emitted per test
+signals = {
+ "pre_create": 0,
+ "post_create": 0,
+ "pre_read": 0,
+ "post_read": 0,
+ "pre_update": 0,
+ "post_update": 0,
+ "pre_destroy": 0,
+ "post_destroy": 0
+}
+
+
+# Signal handlers
+def pre_create_handler(sender, **kwargs):
+ signals['pre_create'] += 1
+
+
+def post_create_handler(sender, **kwargs):
+ signals['post_create'] += 1
+
+
+def pre_read_handler(sender, **kwargs):
+ signals['pre_read'] += 1
+
+
+def post_read_handler(sender, **kwargs):
+ signals['post_read'] += 1
+
+
+def pre_update_handler(sender, **kwargs):
+ signals['pre_update'] += 1
+
+
+def post_update_handler(sender, **kwargs):
+ signals['post_update'] += 1
+
+
+def pre_destroy_handler(sender, **kwargs):
+ signals['pre_destroy'] += 1
+
+
+def post_destroy_handler(sender, **kwargs):
+ signals['post_destroy'] += 1
+
+
+# Connect signal handlers to the signals
+pre_create.connect(pre_create_handler, dispatch_uid=uuid4().hex)
+post_create.connect(post_create_handler, dispatch_uid=uuid4().hex)
+pre_read.connect(pre_read_handler, dispatch_uid=uuid4().hex)
+post_read.connect(post_read_handler, dispatch_uid=uuid4().hex)
+pre_update.connect(pre_update_handler, dispatch_uid=uuid4().hex)
+post_update.connect(post_update_handler, dispatch_uid=uuid4().hex)
+pre_destroy.connect(pre_destroy_handler, dispatch_uid=uuid4().hex)
+post_destroy.connect(post_destroy_handler, dispatch_uid=uuid4().hex)
+
+
+# Resets all entries in the signal dict to 0 in test tearDowns
+def reset_all_signals():
+ for x in signals.keys():
+ signals[x] = 0
+
+
+# Test each signal has been called the correct number of times
+def _test_all_signals(pre_create=0, post_create=0,
+ pre_read=0, post_read=0,
+ pre_update=0, post_update=0,
+ pre_destroy=0, post_destroy=0):
+ assert signals['pre_create'] == pre_create
+ assert signals['post_create'] == post_create
+ assert signals['pre_read'] == pre_read
+ assert signals['post_read'] == post_read
+ assert signals['pre_update'] == pre_update
+ assert signals['post_update'] == post_update
+ assert signals['pre_destroy'] == pre_destroy
+ assert signals['post_destroy'] == post_destroy
+
# Models
class SlugBasedModel(RESTFrameworkModel):
@@ -92,6 +175,7 @@ class TestRootView(TestCase):
for obj in self.objects.all()
]
self.view = RootView.as_view()
+ reset_all_signals()
def test_get_root_view(self):
"""
@@ -100,6 +184,7 @@ class TestRootView(TestCase):
request = factory.get('/')
with self.assertNumQueries(1):
response = self.view(request).render()
+ _test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data
@@ -110,6 +195,10 @@ class TestRootView(TestCase):
request = factory.head('/')
with self.assertNumQueries(1):
response = self.view(request).render()
+ # See here:
+ # https://github.com/django/django/blob/master/django/views/generic/base.py#L63
+ # For why this triggers get triggers
+ _test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
def test_post_root_view(self):
@@ -120,6 +209,7 @@ class TestRootView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
+ _test_all_signals(pre_create=1, post_create=1)
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4)
@@ -133,6 +223,7 @@ class TestRootView(TestCase):
request = factory.put('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
+ _test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "PUT" not allowed.'}
@@ -143,6 +234,7 @@ class TestRootView(TestCase):
request = factory.delete('/')
with self.assertNumQueries(0):
response = self.view(request).render()
+ _test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "DELETE" not allowed.'}
@@ -154,6 +246,7 @@ class TestRootView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
+ _test_all_signals(pre_create=1, post_create=1)
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4)
@@ -168,6 +261,7 @@ class TestRootView(TestCase):
response = self.view(request).render()
expected_error = 'Ensure this field has no more than 100 characters.'
assert expected_error in response.rendered_content.decode('utf-8')
+ _test_all_signals(pre_create=1)
EXPECTED_QUERIES_FOR_PUT = 2
@@ -188,6 +282,7 @@ class TestInstanceView(TestCase):
]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
+ reset_all_signals()
def test_get_instance_view(self):
"""
@@ -196,6 +291,7 @@ class TestInstanceView(TestCase):
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
+ _test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0]
@@ -207,6 +303,7 @@ class TestInstanceView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
+ _test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "POST" not allowed.'}
@@ -218,6 +315,7 @@ class TestInstanceView(TestCase):
request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render()
+ _test_all_signals(pre_update=1, post_update=1)
assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
@@ -232,6 +330,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
+ _test_all_signals(pre_update=1, post_update=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
@@ -244,6 +343,7 @@ class TestInstanceView(TestCase):
request = factory.delete('/1')
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
+ _test_all_signals(pre_destroy=1, post_destroy=1)
assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == six.b('')
ids = [obj.id for obj in self.objects.all()]
@@ -257,6 +357,7 @@ class TestInstanceView(TestCase):
request = factory.get('/a')
with self.assertNumQueries(0):
response = self.view(request, pk='a').render()
+ _test_all_signals(pre_read=1)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self):
@@ -282,6 +383,7 @@ class TestInstanceView(TestCase):
request = factory.put('/1', data, format='json')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
+ _test_all_signals(pre_update=1)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_to_filtered_out_instance(self):
@@ -294,6 +396,7 @@ class TestInstanceView(TestCase):
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
+ _test_all_signals(pre_update=1)
def test_patch_cannot_create_an_object(self):
"""
@@ -305,6 +408,7 @@ class TestInstanceView(TestCase):
response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert not self.objects.filter(id=999).exists()
+ _test_all_signals(pre_update=1)
def test_put_error_instance_view(self):
"""
@@ -315,6 +419,7 @@ class TestInstanceView(TestCase):
response = self.view(request, pk=1).render()
expected_error = 'Ensure this field has no more than 100 characters.'
assert expected_error in response.rendered_content.decode('utf-8')
+ _test_all_signals(pre_update=1)
class TestFKInstanceView(TestCase):
@@ -366,6 +471,7 @@ class TestOverriddenGetObject(TestCase):
return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view()
+ reset_all_signals()
def test_overridden_get_object_view(self):
"""
@@ -374,6 +480,7 @@ class TestOverriddenGetObject(TestCase):
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
+ _test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0]
@@ -395,6 +502,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
def setUp(self):
self.objects = Comment.objects
self.view = CommentView.as_view()
+ reset_all_signals()
def test_create_model_with_auto_now_add_field(self):
"""
@@ -408,6 +516,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1)
assert created.content == 'foobar'
+ _test_all_signals(pre_create=1, post_create=1)
# Test for particularly ugly regression with m2m in browsable API
@@ -436,6 +545,9 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowsableAPI(TestCase):
+ def setUp(self):
+ reset_all_signals()
+
def test_m2m_in_browsable_api(self):
"""
Test for particularly ugly regression with m2m in browsable API
@@ -444,6 +556,7 @@ class TestM2MBrowsableAPI(TestCase):
view = ExampleView().as_view()
response = view(request).render()
assert response.status_code == status.HTTP_200_OK
+ _test_all_signals(pre_read=1, post_read=1)
class InclusiveFilterBackend(object):
@@ -492,6 +605,7 @@ class TestFilterBackendAppliedToViews(TestCase):
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
+ reset_all_signals()
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
@@ -503,6 +617,7 @@ class TestFilterBackendAppliedToViews(TestCase):
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}]
+ _test_all_signals(pre_read=1, post_read=1)
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
"""
@@ -513,6 +628,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == []
+ _test_all_signals(pre_read=1, post_read=1)
def test_get_instance_view_filters_out_name_with_filter_backend(self):
"""
@@ -523,6 +639,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'}
+ _test_all_signals(pre_read=1)
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
"""
@@ -533,6 +650,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'}
+ _test_all_signals(pre_read=1, post_read=1)
def test_dynamic_serializer_form_in_browsable_api(self):
"""
@@ -544,9 +662,13 @@ class TestFilterBackendAppliedToViews(TestCase):
content = response.content.decode('utf8')
assert 'field_b' in content
assert 'field_a' not in content
+ _test_all_signals(pre_read=1, post_read=1)
class TestGuardedQueryset(TestCase):
+ def setUp(self):
+ reset_all_signals()
+
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all()
@@ -556,11 +678,14 @@ class TestGuardedQueryset(TestCase):
view = QuerysetAccessError.as_view()
request = factory.get('/')
+ _test_all_signals()
with pytest.raises(RuntimeError):
view(request).render()
class ApiViewsTests(TestCase):
+ def setUp(self):
+ reset_all_signals()
def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView):
@@ -572,6 +697,7 @@ class ApiViewsTests(TestCase):
view.post('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
+ _test_all_signals(pre_create=1, post_create=1)
def test_destroy_api_view_delete(self):
class MockDestroyApiView(generics.DestroyAPIView):
@@ -583,6 +709,7 @@ class ApiViewsTests(TestCase):
view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
+ _test_all_signals(pre_destroy=1, post_destroy=1)
def test_update_api_view_partial_update(self):
class MockUpdateApiView(generics.UpdateAPIView):
@@ -594,6 +721,7 @@ class ApiViewsTests(TestCase):
view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
+ _test_all_signals(pre_update=1, post_update=1)
def test_retrieve_update_api_view_get(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@@ -605,6 +733,7 @@ class ApiViewsTests(TestCase):
view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
+ _test_all_signals(pre_read=1, post_read=1)
def test_retrieve_update_api_view_put(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@@ -616,6 +745,7 @@ class ApiViewsTests(TestCase):
view.put('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
+ _test_all_signals(pre_update=1, post_update=1)
def test_retrieve_update_api_view_patch(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@@ -627,6 +757,7 @@ class ApiViewsTests(TestCase):
view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
+ _test_all_signals(pre_update=1, post_update=1)
def test_retrieve_destroy_api_view_get(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@@ -638,6 +769,7 @@ class ApiViewsTests(TestCase):
view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
+ _test_all_signals(pre_read=1, post_read=1)
def test_retrieve_destroy_api_view_delete(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@@ -649,6 +781,7 @@ class ApiViewsTests(TestCase):
view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
+ _test_all_signals(pre_destroy=1, post_destroy=1)
class GetObjectOr404Tests(TestCase):