This commit is contained in:
Brian Balsamo 2018-10-11 17:27:46 +00:00 committed by GitHub
commit cf702f8e81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 215 additions and 17 deletions

View File

@ -10,6 +10,10 @@ from django.shortcuts import get_object_or_404 as _get_object_or_404
from rest_framework import mixins, views
from rest_framework.settings import api_settings
from rest_framework.signals import (
post_create, post_destroy, post_read, post_update, pre_create, pre_destroy,
pre_read, pre_update
)
def get_object_or_404(queryset, *filter_args, **filter_kwargs):
@ -189,7 +193,10 @@ class CreateAPIView(mixins.CreateModelMixin,
Concrete view for creating a model instance.
"""
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
pre_create.send(self, request=request)
resp = self.create(request, *args, **kwargs)
post_create.send(self, request=request, response=resp)
return resp
class ListAPIView(mixins.ListModelMixin,
@ -198,7 +205,10 @@ class ListAPIView(mixins.ListModelMixin,
Concrete view for listing a queryset.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.list(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
class RetrieveAPIView(mixins.RetrieveModelMixin,
@ -207,7 +217,10 @@ class RetrieveAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
class DestroyAPIView(mixins.DestroyModelMixin,
@ -216,7 +229,10 @@ class DestroyAPIView(mixins.DestroyModelMixin,
Concrete view for deleting a model instance.
"""
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp
class UpdateAPIView(mixins.UpdateModelMixin,
@ -225,10 +241,16 @@ class UpdateAPIView(mixins.UpdateModelMixin,
Concrete view for updating a model instance.
"""
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
class ListCreateAPIView(mixins.ListModelMixin,
@ -238,10 +260,16 @@ class ListCreateAPIView(mixins.ListModelMixin,
Concrete view for listing a queryset or creating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.list(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
pre_create.send(self, request=request)
resp = self.create(request, *args, **kwargs)
post_create.send(self, request=request, response=resp)
return resp
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
@ -251,13 +279,22 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving, updating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
@ -267,10 +304,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
@ -281,13 +324,25 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
Concrete view for retrieving, updating or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
pre_read.send(self, request=request)
resp = self.retrieve(request, *args, **kwargs)
post_read.send(self, request=request, response=resp)
return resp
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def patch(self, request, *args, **kwargs):
return self.partial_update(request, *args, **kwargs)
pre_update.send(self, request=request)
resp = self.partial_update(request, *args, **kwargs)
post_update.send(self, request=request, response=resp)
return resp
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
pre_destroy.send(self, request=request)
resp = self.destroy(request, *args, **kwargs)
post_destroy.send(self, request=request, response=resp)
return resp

10
rest_framework/signals.py Normal file
View File

@ -0,0 +1,10 @@
from django.dispatch import Signal
pre_create = Signal(providing_args=['request'])
post_create = Signal(providing_args=['request', 'response'])
pre_read = Signal(providing_args=['request'])
post_read = Signal(providing_args=['request', 'response'])
pre_update = Signal(providing_args=['request'])
post_update = Signal(providing_args=['request', 'response'])
pre_destroy = Signal(providing_args=['request'])
post_destroy = Signal(providing_args=['request', 'response'])

View File

@ -1,5 +1,7 @@
from __future__ import unicode_literals
from uuid import uuid4
import pytest
from django.db import models
from django.http import Http404
@ -9,6 +11,10 @@ from django.utils import six
from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response
from rest_framework.signals import (
post_create, post_destroy, post_read, post_update, pre_create, pre_destroy,
pre_read, pre_update
)
from rest_framework.test import APIRequestFactory
from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel,
@ -17,6 +23,83 @@ from tests.models import (
factory = APIRequestFactory()
# Signals dict to catch signals emitted per test
signals = {
"pre_create": 0,
"post_create": 0,
"pre_read": 0,
"post_read": 0,
"pre_update": 0,
"post_update": 0,
"pre_destroy": 0,
"post_destroy": 0
}
# Signal handlers
def pre_create_handler(sender, **kwargs):
signals['pre_create'] += 1
def post_create_handler(sender, **kwargs):
signals['post_create'] += 1
def pre_read_handler(sender, **kwargs):
signals['pre_read'] += 1
def post_read_handler(sender, **kwargs):
signals['post_read'] += 1
def pre_update_handler(sender, **kwargs):
signals['pre_update'] += 1
def post_update_handler(sender, **kwargs):
signals['post_update'] += 1
def pre_destroy_handler(sender, **kwargs):
signals['pre_destroy'] += 1
def post_destroy_handler(sender, **kwargs):
signals['post_destroy'] += 1
# Connect signal handlers to the signals
pre_create.connect(pre_create_handler, dispatch_uid=uuid4().hex)
post_create.connect(post_create_handler, dispatch_uid=uuid4().hex)
pre_read.connect(pre_read_handler, dispatch_uid=uuid4().hex)
post_read.connect(post_read_handler, dispatch_uid=uuid4().hex)
pre_update.connect(pre_update_handler, dispatch_uid=uuid4().hex)
post_update.connect(post_update_handler, dispatch_uid=uuid4().hex)
pre_destroy.connect(pre_destroy_handler, dispatch_uid=uuid4().hex)
post_destroy.connect(post_destroy_handler, dispatch_uid=uuid4().hex)
# Resets all entries in the signal dict to 0 in test tearDowns
def reset_all_signals():
for x in signals.keys():
signals[x] = 0
# Test each signal has been called the correct number of times
def _test_all_signals(pre_create=0, post_create=0,
pre_read=0, post_read=0,
pre_update=0, post_update=0,
pre_destroy=0, post_destroy=0):
assert signals['pre_create'] == pre_create
assert signals['post_create'] == post_create
assert signals['pre_read'] == pre_read
assert signals['post_read'] == post_read
assert signals['pre_update'] == pre_update
assert signals['post_update'] == post_update
assert signals['pre_destroy'] == pre_destroy
assert signals['post_destroy'] == post_destroy
# Models
class SlugBasedModel(RESTFrameworkModel):
@ -92,6 +175,7 @@ class TestRootView(TestCase):
for obj in self.objects.all()
]
self.view = RootView.as_view()
reset_all_signals()
def test_get_root_view(self):
"""
@ -100,6 +184,7 @@ class TestRootView(TestCase):
request = factory.get('/')
with self.assertNumQueries(1):
response = self.view(request).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data
@ -110,6 +195,10 @@ class TestRootView(TestCase):
request = factory.head('/')
with self.assertNumQueries(1):
response = self.view(request).render()
# See here:
# https://github.com/django/django/blob/master/django/views/generic/base.py#L63
# For why this triggers get triggers
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
def test_post_root_view(self):
@ -120,6 +209,7 @@ class TestRootView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
_test_all_signals(pre_create=1, post_create=1)
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4)
@ -133,6 +223,7 @@ class TestRootView(TestCase):
request = factory.put('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "PUT" not allowed.'}
@ -143,6 +234,7 @@ class TestRootView(TestCase):
request = factory.delete('/')
with self.assertNumQueries(0):
response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "DELETE" not allowed.'}
@ -154,6 +246,7 @@ class TestRootView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(1):
response = self.view(request).render()
_test_all_signals(pre_create=1, post_create=1)
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
created = self.objects.get(id=4)
@ -168,6 +261,7 @@ class TestRootView(TestCase):
response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8')
_test_all_signals(pre_create=1)
EXPECTED_QUERIES_FOR_PUT = 2
@ -188,6 +282,7 @@ class TestInstanceView(TestCase):
]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
reset_all_signals()
def test_get_instance_view(self):
"""
@ -196,6 +291,7 @@ class TestInstanceView(TestCase):
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0]
@ -207,6 +303,7 @@ class TestInstanceView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
_test_all_signals()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {"detail": 'Method "POST" not allowed.'}
@ -218,6 +315,7 @@ class TestInstanceView(TestCase):
request = factory.put('/1', data, format='json')
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render()
_test_all_signals(pre_update=1, post_update=1)
assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
@ -232,6 +330,7 @@ class TestInstanceView(TestCase):
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
_test_all_signals(pre_update=1, post_update=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'}
updated = self.objects.get(id=1)
@ -244,6 +343,7 @@ class TestInstanceView(TestCase):
request = factory.delete('/1')
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
_test_all_signals(pre_destroy=1, post_destroy=1)
assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == six.b('')
ids = [obj.id for obj in self.objects.all()]
@ -257,6 +357,7 @@ class TestInstanceView(TestCase):
request = factory.get('/a')
with self.assertNumQueries(0):
response = self.view(request, pk='a').render()
_test_all_signals(pre_read=1)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self):
@ -282,6 +383,7 @@ class TestInstanceView(TestCase):
request = factory.put('/1', data, format='json')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
_test_all_signals(pre_update=1)
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_to_filtered_out_instance(self):
@ -294,6 +396,7 @@ class TestInstanceView(TestCase):
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
_test_all_signals(pre_update=1)
def test_patch_cannot_create_an_object(self):
"""
@ -305,6 +408,7 @@ class TestInstanceView(TestCase):
response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert not self.objects.filter(id=999).exists()
_test_all_signals(pre_update=1)
def test_put_error_instance_view(self):
"""
@ -315,6 +419,7 @@ class TestInstanceView(TestCase):
response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8')
_test_all_signals(pre_update=1)
class TestFKInstanceView(TestCase):
@ -366,6 +471,7 @@ class TestOverriddenGetObject(TestCase):
return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view()
reset_all_signals()
def test_overridden_get_object_view(self):
"""
@ -374,6 +480,7 @@ class TestOverriddenGetObject(TestCase):
request = factory.get('/1')
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
_test_all_signals(pre_read=1, post_read=1)
assert response.status_code == status.HTTP_200_OK
assert response.data == self.data[0]
@ -395,6 +502,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
def setUp(self):
self.objects = Comment.objects
self.view = CommentView.as_view()
reset_all_signals()
def test_create_model_with_auto_now_add_field(self):
"""
@ -408,6 +516,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1)
assert created.content == 'foobar'
_test_all_signals(pre_create=1, post_create=1)
# Test for particularly ugly regression with m2m in browsable API
@ -436,6 +545,9 @@ class ExampleView(generics.ListCreateAPIView):
class TestM2MBrowsableAPI(TestCase):
def setUp(self):
reset_all_signals()
def test_m2m_in_browsable_api(self):
"""
Test for particularly ugly regression with m2m in browsable API
@ -444,6 +556,7 @@ class TestM2MBrowsableAPI(TestCase):
view = ExampleView().as_view()
response = view(request).render()
assert response.status_code == status.HTTP_200_OK
_test_all_signals(pre_read=1, post_read=1)
class InclusiveFilterBackend(object):
@ -492,6 +605,7 @@ class TestFilterBackendAppliedToViews(TestCase):
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
reset_all_signals()
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
@ -503,6 +617,7 @@ class TestFilterBackendAppliedToViews(TestCase):
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}]
_test_all_signals(pre_read=1, post_read=1)
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
"""
@ -513,6 +628,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == []
_test_all_signals(pre_read=1, post_read=1)
def test_get_instance_view_filters_out_name_with_filter_backend(self):
"""
@ -523,6 +639,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'}
_test_all_signals(pre_read=1)
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
"""
@ -533,6 +650,7 @@ class TestFilterBackendAppliedToViews(TestCase):
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'}
_test_all_signals(pre_read=1, post_read=1)
def test_dynamic_serializer_form_in_browsable_api(self):
"""
@ -544,9 +662,13 @@ class TestFilterBackendAppliedToViews(TestCase):
content = response.content.decode('utf8')
assert 'field_b' in content
assert 'field_a' not in content
_test_all_signals(pre_read=1, post_read=1)
class TestGuardedQueryset(TestCase):
def setUp(self):
reset_all_signals()
def test_guarded_queryset(self):
class QuerysetAccessError(generics.ListAPIView):
queryset = BasicModel.objects.all()
@ -556,11 +678,14 @@ class TestGuardedQueryset(TestCase):
view = QuerysetAccessError.as_view()
request = factory.get('/')
_test_all_signals()
with pytest.raises(RuntimeError):
view(request).render()
class ApiViewsTests(TestCase):
def setUp(self):
reset_all_signals()
def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView):
@ -572,6 +697,7 @@ class ApiViewsTests(TestCase):
view.post('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_create=1, post_create=1)
def test_destroy_api_view_delete(self):
class MockDestroyApiView(generics.DestroyAPIView):
@ -583,6 +709,7 @@ class ApiViewsTests(TestCase):
view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_destroy=1, post_destroy=1)
def test_update_api_view_partial_update(self):
class MockUpdateApiView(generics.UpdateAPIView):
@ -594,6 +721,7 @@ class ApiViewsTests(TestCase):
view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_update_api_view_get(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -605,6 +733,7 @@ class ApiViewsTests(TestCase):
view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_read=1, post_read=1)
def test_retrieve_update_api_view_put(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -616,6 +745,7 @@ class ApiViewsTests(TestCase):
view.put('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_update_api_view_patch(self):
class MockRetrieveUpdateApiView(generics.RetrieveUpdateAPIView):
@ -627,6 +757,7 @@ class ApiViewsTests(TestCase):
view.patch('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_update=1, post_update=1)
def test_retrieve_destroy_api_view_get(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@ -638,6 +769,7 @@ class ApiViewsTests(TestCase):
view.get('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_read=1, post_read=1)
def test_retrieve_destroy_api_view_delete(self):
class MockRetrieveDestroyUApiView(generics.RetrieveDestroyAPIView):
@ -649,6 +781,7 @@ class ApiViewsTests(TestCase):
view.delete('test request', 'test arg', test_kwarg='test')
assert view.called is True
assert view.call_args == data
_test_all_signals(pre_destroy=1, post_destroy=1)
class GetObjectOr404Tests(TestCase):