diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 9ae8cf0aa..36ecf9150 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -18,6 +18,16 @@ class GenericAPIView(views.APIView): model = None serializer_class = None model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS + filter_backend = api_settings.FILTER_BACKEND + + def filter_queryset(self, queryset): + """ + Given a queryset, filter it with whichever filter backend is in use. + """ + if not self.filter_backend: + return queryset + backend = self.filter_backend() + return backend.filter_queryset(self.request, queryset, self) def get_serializer_context(self): """ @@ -81,16 +91,6 @@ class MultipleObjectAPIView(MultipleObjectMixin, GenericAPIView): paginate_by = api_settings.PAGINATE_BY paginate_by_param = api_settings.PAGINATE_BY_PARAM pagination_serializer_class = api_settings.DEFAULT_PAGINATION_SERIALIZER_CLASS - filter_backend = api_settings.FILTER_BACKEND - - def filter_queryset(self, queryset): - """ - Given a queryset, filter it with whichever filter backend is in use. - """ - if not self.filter_backend: - return queryset - backend = self.filter_backend() - return backend.filter_queryset(self.request, queryset, self) def get_pagination_serializer(self, page=None): """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 97201c4b1..8e4012049 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -97,7 +97,9 @@ class RetrieveModelMixin(object): Should be mixed in with `SingleObjectAPIView`. """ def retrieve(self, request, *args, **kwargs): - self.object = self.get_object() + queryset = self.get_queryset() + filtered_queryset = self.filter_queryset(queryset) + self.object = self.get_object(filtered_queryset) serializer = self.get_serializer(self.object) return Response(serializer.data) diff --git a/rest_framework/tests/generics.py b/rest_framework/tests/generics.py index f8f2ddaac..f70934016 100644 --- a/rest_framework/tests/generics.py +++ b/rest_framework/tests/generics.py @@ -350,3 +350,78 @@ class TestM2MBrowseableAPI(TestCase): view = ExampleView().as_view() response = view(request).render() self.assertEqual(response.status_code, status.HTTP_200_OK) + + +class InclusiveFilterBackend(object): + def filter_queryset(self, request, queryset, view): + return queryset.filter(text='foo') + + +class ExclusiveFilterBackend(object): + def filter_queryset(self, request, queryset, view): + return queryset.filter(text='other') + + +class TestFilterBackendAppliedToViews(TestCase): + + def setUp(self): + """ + Create 3 BasicModel instances to filter on. + """ + items = ['foo', 'bar', 'baz'] + for item in items: + BasicModel(text=item).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.root_view = RootView.as_view() + self.instance_view = InstanceView.as_view() + self.original_root_backend = getattr(RootView, 'filter_backend') + self.original_instance_backend = getattr(InstanceView, 'filter_backend') + + def tearDown(self): + setattr(RootView, 'filter_backend', self.original_root_backend) + setattr(InstanceView, 'filter_backend', self.original_instance_backend) + + def test_get_root_view_filters_by_name_with_filter_backend(self): + """ + GET requests to ListCreateAPIView should return filtered list. + """ + setattr(RootView, 'filter_backend', InclusiveFilterBackend) + request = factory.get('/') + response = self.root_view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) + + def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): + """ + GET requests to ListCreateAPIView should return empty list when all models are filtered out. + """ + setattr(RootView, 'filter_backend', ExclusiveFilterBackend) + request = factory.get('/') + response = self.root_view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, []) + + def test_get_instance_view_filters_out_name_with_filter_backend(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. + """ + setattr(InstanceView, 'filter_backend', ExclusiveFilterBackend) + request = factory.get('/1') + response = self.instance_view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(response.data, {'detail': 'Not found'}) + + def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): + """ + GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded + """ + setattr(InstanceView, 'filter_backend', InclusiveFilterBackend) + request = factory.get('/1') + response = self.instance_view(request, pk=1).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'id': 1, 'text': 'foo'})