diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 8d0bf284a..56349d2b9 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -10,6 +10,10 @@ from django.shortcuts import get_object_or_404 as _get_object_or_404 from rest_framework import mixins, views from rest_framework.settings import api_settings +from rest_framework.signals import ( + post_create, post_destroy, post_read, post_update, pre_create, pre_destroy, + pre_read, pre_update +) def get_object_or_404(queryset, *filter_args, **filter_kwargs): @@ -189,7 +193,10 @@ class CreateAPIView(mixins.CreateModelMixin, Concrete view for creating a model instance. """ def post(self, request, *args, **kwargs): - return self.create(request, *args, **kwargs) + pre_create.send(self, request=request) + resp = self.create(request, *args, **kwargs) + post_create.send(self, request=request, response=resp) + return resp class ListAPIView(mixins.ListModelMixin, @@ -198,7 +205,10 @@ class ListAPIView(mixins.ListModelMixin, Concrete view for listing a queryset. """ def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) + pre_read.send(self, request=request) + resp = self.list(request, *args, **kwargs) + post_read.send(self, request=request, response=resp) + return resp class RetrieveAPIView(mixins.RetrieveModelMixin, @@ -207,7 +217,10 @@ class RetrieveAPIView(mixins.RetrieveModelMixin, Concrete view for retrieving a model instance. """ def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) + pre_read.send(self, request=request) + resp = self.retrieve(request, *args, **kwargs) + post_read.send(self, request=request, response=resp) + return resp class DestroyAPIView(mixins.DestroyModelMixin, @@ -216,7 +229,10 @@ class DestroyAPIView(mixins.DestroyModelMixin, Concrete view for deleting a model instance. """ def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) + pre_destroy.send(self, request=request) + resp = self.destroy(request, *args, **kwargs) + post_destroy.send(self, request=request, response=resp) + return resp class UpdateAPIView(mixins.UpdateModelMixin, @@ -225,10 +241,16 @@ class UpdateAPIView(mixins.UpdateModelMixin, Concrete view for updating a model instance. """ def put(self, request, *args, **kwargs): - return self.update(request, *args, **kwargs) + pre_update.send(self, request=request) + resp = self.update(request, *args, **kwargs) + post_update.send(self, request=request, response=resp) + return resp def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) + pre_update.send(self, request=request) + resp = self.partial_update(request, *args, **kwargs) + post_update.send(self, request=request, response=resp) + return resp class ListCreateAPIView(mixins.ListModelMixin, @@ -238,10 +260,16 @@ class ListCreateAPIView(mixins.ListModelMixin, Concrete view for listing a queryset or creating a model instance. """ def get(self, request, *args, **kwargs): - return self.list(request, *args, **kwargs) + pre_read.send(self, request=request) + resp = self.list(request, *args, **kwargs) + post_read.send(self, request=request, response=resp) + return resp def post(self, request, *args, **kwargs): - return self.create(request, *args, **kwargs) + pre_create.send(self, request=request) + resp = self.create(request, *args, **kwargs) + post_create.send(self, request=request, response=resp) + return resp class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, @@ -251,13 +279,22 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin, Concrete view for retrieving, updating a model instance. """ def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) + pre_read.send(self, request=request) + resp = self.retrieve(request, *args, **kwargs) + post_read.send(self, request=request, response=resp) + return resp def put(self, request, *args, **kwargs): - return self.update(request, *args, **kwargs) + pre_update.send(self, request=request) + resp = self.update(request, *args, **kwargs) + post_update.send(self, request=request, response=resp) + return resp def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) + pre_update.send(self, request=request) + resp = self.partial_update(request, *args, **kwargs) + post_update.send(self, request=request, response=resp) + return resp class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, @@ -267,10 +304,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin, Concrete view for retrieving or deleting a model instance. """ def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) + pre_read.send(self, request=request) + resp = self.retrieve(request, *args, **kwargs) + post_read.send(self, request=request, response=resp) + return resp def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) + pre_destroy.send(self, request=request) + resp = self.destroy(request, *args, **kwargs) + post_destroy.send(self, request=request, response=resp) + return resp class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, @@ -281,13 +324,25 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin, Concrete view for retrieving, updating or deleting a model instance. """ def get(self, request, *args, **kwargs): - return self.retrieve(request, *args, **kwargs) + pre_read.send(self, request=request) + resp = self.retrieve(request, *args, **kwargs) + post_read.send(self, request=request, response=resp) + return resp def put(self, request, *args, **kwargs): - return self.update(request, *args, **kwargs) + pre_update.send(self, request=request) + resp = self.update(request, *args, **kwargs) + post_update.send(self, request=request, response=resp) + return resp def patch(self, request, *args, **kwargs): - return self.partial_update(request, *args, **kwargs) + pre_update.send(self, request=request) + resp = self.partial_update(request, *args, **kwargs) + post_update.send(self, request=request, response=resp) + return resp def delete(self, request, *args, **kwargs): - return self.destroy(request, *args, **kwargs) + pre_destroy.send(self, request=request) + resp = self.destroy(request, *args, **kwargs) + post_destroy.send(self, request=request, response=resp) + return resp diff --git a/rest_framework/signals.py b/rest_framework/signals.py new file mode 100644 index 000000000..9b55e8933 --- /dev/null +++ b/rest_framework/signals.py @@ -0,0 +1,10 @@ +from django.dispatch import Signal + +pre_create = Signal(providing_args=['request']) +post_create = Signal(providing_args=['request', 'response']) +pre_read = Signal(providing_args=['request']) +post_read = Signal(providing_args=['request', 'response']) +pre_update = Signal(providing_args=['request']) +post_update = Signal(providing_args=['request', 'response']) +pre_destroy = Signal(providing_args=['request']) +post_destroy = Signal(providing_args=['request', 'response']) diff --git a/tests/test_generics.py b/tests/test_generics.py index c0ff1c5c4..6f2e1f093 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals +from uuid import uuid4 + import pytest from django.db import models from django.http import Http404 @@ -9,6 +11,10 @@ from django.utils import six from rest_framework import generics, renderers, serializers, status from rest_framework.response import Response +from rest_framework.signals import ( + post_create, post_destroy, post_read, post_update, pre_create, pre_destroy, + pre_read, pre_update +) from rest_framework.test import APIRequestFactory from tests.models import ( BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel, @@ -17,6 +23,83 @@ from tests.models import ( factory = APIRequestFactory() +# Signals dict to catch signals emitted per test +signals = { + "pre_create": 0, + "post_create": 0, + "pre_read": 0, + "post_read": 0, + "pre_update": 0, + "post_update": 0, + "pre_destroy": 0, + "post_destroy": 0 +} + + +# Signal handlers +def pre_create_handler(sender, **kwargs): + signals['pre_create'] += 1 + + +def post_create_handler(sender, **kwargs): + signals['post_create'] += 1 + + +def pre_read_handler(sender, **kwargs): + signals['pre_read'] += 1 + + +def post_read_handler(sender, **kwargs): + signals['post_read'] += 1 + + +def pre_update_handler(sender, **kwargs): + signals['pre_update'] += 1 + + +def post_update_handler(sender, **kwargs): + signals['post_update'] += 1 + + +def pre_destroy_handler(sender, **kwargs): + signals['pre_destroy'] += 1 + + +def post_destroy_handler(sender, **kwargs): + signals['post_destroy'] += 1 + + +# Connect signal handlers to the signals +pre_create.connect(pre_create_handler, dispatch_uid=uuid4().hex) +post_create.connect(post_create_handler, dispatch_uid=uuid4().hex) +pre_read.connect(pre_read_handler, dispatch_uid=uuid4().hex) +post_read.connect(post_read_handler, dispatch_uid=uuid4().hex) +pre_update.connect(pre_update_handler, dispatch_uid=uuid4().hex) +post_update.connect(post_update_handler, dispatch_uid=uuid4().hex) +pre_destroy.connect(pre_destroy_handler, dispatch_uid=uuid4().hex) +post_destroy.connect(post_destroy_handler, dispatch_uid=uuid4().hex) + + +# Resets all entries in the signal dict to 0 in test tearDowns +def reset_all_signals(): + for x in signals.keys(): + signals[x] = 0 + + +# Test each signal has been called the correct number of times +def _test_all_signals(pre_create=0, post_create=0, + pre_read=0, post_read=0, + pre_update=0, post_update=0, + pre_destroy=0, post_destroy=0): + assert signals['pre_create'] == pre_create + assert signals['post_create'] == post_create + assert signals['pre_read'] == pre_read + assert signals['post_read'] == post_read + assert signals['pre_update'] == pre_update + assert signals['post_update'] == post_update + assert signals['pre_destroy'] == pre_destroy + assert signals['post_destroy'] == post_destroy + # Models class SlugBasedModel(RESTFrameworkModel): @@ -92,6 +175,7 @@ class TestRootView(TestCase): for obj in self.objects.all() ] self.view = RootView.as_view() + reset_all_signals() def test_get_root_view(self): """ @@ -100,6 +184,7 @@ class TestRootView(TestCase): request = factory.get('/') with self.assertNumQueries(1): response = self.view(request).render() + _test_all_signals(pre_read=1, post_read=1) assert response.status_code == status.HTTP_200_OK assert response.data == self.data @@ -110,6 +195,10 @@ class TestRootView(TestCase): request = factory.head('/') with self.assertNumQueries(1): response = self.view(request).render() + # See here: + # https://github.com/django/django/blob/master/django/views/generic/base.py#L63 + # For why this triggers get triggers + _test_all_signals(pre_read=1, post_read=1) assert response.status_code == status.HTTP_200_OK def test_post_root_view(self): @@ -120,6 +209,7 @@ class TestRootView(TestCase): request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() + _test_all_signals(pre_create=1, post_create=1) assert response.status_code == status.HTTP_201_CREATED assert response.data == {'id': 4, 'text': 'foobar'} created = self.objects.get(id=4) @@ -133,6 +223,7 @@ class TestRootView(TestCase): request = factory.put('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() + _test_all_signals() assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.data == {"detail": 'Method "PUT" not allowed.'} @@ -143,6 +234,7 @@ class TestRootView(TestCase): request = factory.delete('/') with self.assertNumQueries(0): response = self.view(request).render() + _test_all_signals() assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.data == {"detail": 'Method "DELETE" not allowed.'} @@ -154,6 +246,7 @@ class TestRootView(TestCase): request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() + _test_all_signals(pre_create=1, post_create=1) assert response.status_code == status.HTTP_201_CREATED assert response.data == {'id': 4, 'text': 'foobar'} created = self.objects.get(id=4) @@ -168,6 +261,7 @@ class TestRootView(TestCase): response = self.view(request).render() expected_error = 'Ensure this field has no more than 100 characters.' assert expected_error in response.rendered_content.decode('utf-8') + _test_all_signals(pre_create=1) EXPECTED_QUERIES_FOR_PUT = 2 @@ -188,6 +282,7 @@ class TestInstanceView(TestCase): ] self.view = InstanceView.as_view() self.slug_based_view = SlugBasedInstanceView.as_view() + reset_all_signals() def test_get_instance_view(self): """ @@ -196,6 +291,7 @@ class TestInstanceView(TestCase): request = factory.get('/1') with self.assertNumQueries(1): response = self.view(request, pk=1).render() + _test_all_signals(pre_read=1, post_read=1) assert response.status_code == status.HTTP_200_OK assert response.data == self.data[0] @@ -207,6 +303,7 @@ class TestInstanceView(TestCase): request = factory.post('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() + _test_all_signals() assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED assert response.data == {"detail": 'Method "POST" not allowed.'} @@ -218,6 +315,7 @@ class TestInstanceView(TestCase): request = factory.put('/1', data, format='json') with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): response = self.view(request, pk='1').render() + _test_all_signals(pre_update=1, post_update=1) assert response.status_code == status.HTTP_200_OK assert dict(response.data) == {'id': 1, 'text': 'foobar'} updated = self.objects.get(id=1) @@ -232,6 +330,7 @@ class TestInstanceView(TestCase): with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT): response = self.view(request, pk=1).render() + _test_all_signals(pre_update=1, post_update=1) assert response.status_code == status.HTTP_200_OK assert response.data == {'id': 1, 'text': 'foobar'} updated = self.objects.get(id=1) @@ -244,6 +343,7 @@ class TestInstanceView(TestCase): request = factory.delete('/1') with self.assertNumQueries(2): response = self.view(request, pk=1).render() + _test_all_signals(pre_destroy=1, post_destroy=1) assert response.status_code == status.HTTP_204_NO_CONTENT assert response.content == six.b('') ids = [obj.id for obj in self.objects.all()] @@ -257,6 +357,7 @@ class TestInstanceView(TestCase): request = factory.get('/a') with self.assertNumQueries(0): response = self.view(request, pk='a').render() + _test_all_signals(pre_read=1) assert response.status_code == status.HTTP_404_NOT_FOUND def test_put_cannot_set_id(self): @@ -282,6 +383,7 @@ class TestInstanceView(TestCase): request = factory.put('/1', data, format='json') with self.assertNumQueries(1): response = self.view(request, pk=1).render() + _test_all_signals(pre_update=1) assert response.status_code == status.HTTP_404_NOT_FOUND def test_put_to_filtered_out_instance(self): @@ -294,6 +396,7 @@ class TestInstanceView(TestCase): request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') response = self.view(request, pk=filtered_out_pk).render() assert response.status_code == status.HTTP_404_NOT_FOUND + _test_all_signals(pre_update=1) def test_patch_cannot_create_an_object(self): """ @@ -305,6 +408,7 @@ class TestInstanceView(TestCase): response = self.view(request, pk=999).render() assert response.status_code == status.HTTP_404_NOT_FOUND assert not self.objects.filter(id=999).exists() + _test_all_signals(pre_update=1) def test_put_error_instance_view(self): """ @@ -315,6 +419,7 @@ class TestInstanceView(TestCase): response = self.view(request, pk=1).render() expected_error = 'Ensure this field has no more than 100 characters.' assert expected_error in response.rendered_content.decode('utf-8') + _test_all_signals(pre_update=1) class TestFKInstanceView(TestCase): @@ -366,6 +471,7 @@ class TestOverriddenGetObject(TestCase): return get_object_or_404(BasicModel.objects.all(), id=pk) self.view = OverriddenGetObjectView.as_view() + reset_all_signals() def test_overridden_get_object_view(self): """ @@ -374,6 +480,7 @@ class TestOverriddenGetObject(TestCase): request = factory.get('/1') with self.assertNumQueries(1): response = self.view(request, pk=1).render() + _test_all_signals(pre_read=1, post_read=1) assert response.status_code == status.HTTP_200_OK assert response.data == self.data[0] @@ -395,6 +502,7 @@ class TestCreateModelWithAutoNowAddField(TestCase): def setUp(self): self.objects = Comment.objects self.view = CommentView.as_view() + reset_all_signals() def test_create_model_with_auto_now_add_field(self): """ @@ -408,6 +516,7 @@ class TestCreateModelWithAutoNowAddField(TestCase): assert response.status_code == status.HTTP_201_CREATED created = self.objects.get(id=1) assert created.content == 'foobar' + _test_all_signals(pre_create=1, post_create=1) # Test for particularly ugly regression with m2m in browsable API @@ -436,6 +545,9 @@ class ExampleView(generics.ListCreateAPIView): class TestM2MBrowsableAPI(TestCase): + def setUp(self): + reset_all_signals() + def test_m2m_in_browsable_api(self): """ Test for particularly ugly regression with m2m in browsable API @@ -444,6 +556,7 @@ class TestM2MBrowsableAPI(TestCase): view = ExampleView().as_view() response = view(request).render() assert response.status_code == status.HTTP_200_OK + _test_all_signals(pre_read=1, post_read=1) class InclusiveFilterBackend(object): @@ -492,6 +605,7 @@ class TestFilterBackendAppliedToViews(TestCase): {'id': obj.id, 'text': obj.text} for obj in self.objects.all() ] + reset_all_signals() def test_get_root_view_filters_by_name_with_filter_backend(self): """ @@ -503,6 +617,7 @@ class TestFilterBackendAppliedToViews(TestCase): assert response.status_code == status.HTTP_200_OK assert len(response.data) == 1 assert response.data == [{'id': 1, 'text': 'foo'}] + _test_all_signals(pre_read=1, post_read=1) def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): """ @@ -513,6 +628,7 @@ class TestFilterBackendAppliedToViews(TestCase): response = root_view(request).render() assert response.status_code == status.HTTP_200_OK assert response.data == [] + _test_all_signals(pre_read=1, post_read=1) def test_get_instance_view_filters_out_name_with_filter_backend(self): """ @@ -523,6 +639,7 @@ class TestFilterBackendAppliedToViews(TestCase): response = instance_view(request, pk=1).render() assert response.status_code == status.HTTP_404_NOT_FOUND assert response.data == {'detail': 'Not found.'} + _test_all_signals(pre_read=1) def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): """ @@ -533,6 +650,7 @@ class TestFilterBackendAppliedToViews(TestCase): response = instance_view(request, pk=1).render() assert response.status_code == status.HTTP_200_OK assert response.data == {'id': 1, 'text': 'foo'} + _test_all_signals(pre_read=1, post_read=1) def test_dynamic_serializer_form_in_browsable_api(self): """ @@ -544,9 +662,13 @@ class TestFilterBackendAppliedToViews(TestCase): content = response.content.decode('utf8') assert 'field_b' in content assert 'field_a' not in content + _test_all_signals(pre_read=1, post_read=1) class TestGuardedQueryset(TestCase): + def setUp(self): + reset_all_signals() + def test_guarded_queryset(self): class QuerysetAccessError(generics.ListAPIView): queryset = BasicModel.objects.all() @@ -556,11 +678,14 @@ class TestGuardedQueryset(TestCase): view = QuerysetAccessError.as_view() request = factory.get('/') + _test_all_signals() with pytest.raises(RuntimeError): view(request).render() class ApiViewsTests(TestCase): + def setUp(self): + reset_all_signals() def test_create_api_view_post(self): class MockCreateApiView(generics.CreateAPIView): @@ -572,6 +697,7 @@ class ApiViewsTests(TestCase): view.post('test request', 'test arg', test_kwarg='test') assert view.called is True assert view.call_args == data + _test_all_signals(pre_create=1, post_create=1) def test_destroy_api_view_delete(self): class MockDestroyApiView(generics.DestroyAPIView): @@ -583,6 +709,7 @@ class ApiViewsTests(TestCase): view.delete('test request', 'test arg', test_kwarg='test') assert view.called is True assert view.call_args == data + _test_all_signals(pre_destroy=1, post_destroy=1) def test_update_api_view_partial_update(self): class MockUpdateApiView(generics.UpdateAPIView): @@ -594,6 +721,7 @@ class ApiViewsTests(TestCase): view.patch('test request', 'test arg', test_kwarg='test') assert view.called is True assert view.call_args == data + _test_all_signals(pre_update=1, post_update=1) def test_retrieve_update_api_view_get(self): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): @@ -605,6 +733,7 @@ class ApiViewsTests(TestCase): view.get('test request', 'test arg', test_kwarg='test') assert view.called is True assert view.call_args == data + _test_all_signals(pre_read=1, post_read=1) def test_retrieve_update_api_view_put(self): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): @@ -616,6 +745,7 @@ class ApiViewsTests(TestCase): view.put('test request', 'test arg', test_kwarg='test') assert view.called is True assert view.call_args == data + _test_all_signals(pre_update=1, post_update=1) def test_retrieve_update_api_view_patch(self): class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView): @@ -627,6 +757,7 @@ class ApiViewsTests(TestCase): view.patch('test request', 'test arg', test_kwarg='test') assert view.called is True assert view.call_args == data + _test_all_signals(pre_update=1, post_update=1) def test_retrieve_destroy_api_view_get(self): class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): @@ -638,6 +769,7 @@ class ApiViewsTests(TestCase): view.get('test request', 'test arg', test_kwarg='test') assert view.called is True assert view.call_args == data + _test_all_signals(pre_read=1, post_read=1) def test_retrieve_destroy_api_view_delete(self): class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView): @@ -649,6 +781,7 @@ class ApiViewsTests(TestCase): view.delete('test request', 'test arg', test_kwarg='test') assert view.called is True assert view.call_args == data + _test_all_signals(pre_destroy=1, post_destroy=1) class GetObjectOr404Tests(TestCase):