From de69a28b9e786b8c759cda4acedb0a1b8542298b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 8 May 2013 20:18:01 +0100 Subject: [PATCH] Test and fix for #814. --- rest_framework/filters.py | 14 ++++++++++---- rest_framework/tests/filterset.py | 28 +++++++++++++++++++++++++++- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 571704dc9..f2163f6fb 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -38,21 +38,27 @@ class DjangoFilterBackend(BaseFilterBackend): """ filter_class = getattr(view, 'filter_class', None) filter_fields = getattr(view, 'filter_fields', None) - view_model = getattr(view, 'model', None) + model_cls = getattr(view, 'model', None) + queryset = getattr(view, 'queryset', None) + if model_cls is None and queryset is not None: + model_cls = queryset.model if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, view_model), \ + assert issubclass(filter_model, model_cls), \ 'FilterSet model %s does not match view model %s' % \ - (filter_model, view_model) + (filter_model, model_cls) return filter_class if filter_fields: + assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \ + 'on a view which does not have a .model or .queryset attribute.' + class AutoFilterSet(self.default_filter_set): class Meta: - model = view_model + model = model_cls fields = filter_fields return AutoFilterSet diff --git a/rest_framework/tests/filterset.py b/rest_framework/tests/filterset.py index 1e53a5cdb..023bd0166 100644 --- a/rest_framework/tests/filterset.py +++ b/rest_framework/tests/filterset.py @@ -5,7 +5,7 @@ from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import RequestFactory from django.utils import unittest -from rest_framework import generics, status, filters +from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url from rest_framework.tests.models import FilterableItem, BasicModel @@ -52,6 +52,17 @@ if django_filters: filter_class = SeveralFieldsFilter filter_backend = filters.DjangoFilterBackend + # Regression test for #814 + class FilterableItemSerializer(serializers.ModelSerializer): + class Meta: + model = FilterableItem + + class FilterFieldsQuerysetView(generics.ListCreateAPIView): + queryset = FilterableItem.objects.all() + serializer_class = FilterableItemSerializer + filter_fields = ['decimal', 'date'] + filter_backend = filters.DjangoFilterBackend + urlpatterns = patterns('', url(r'^(?P\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'), @@ -114,6 +125,21 @@ class IntegrationTestFiltering(CommonFilteringTestCase): expected_data = [f for f in self.data if f['date'] == search_date] self.assertEqual(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') + def test_filter_with_queryset(self): + """ + Regression test for #814. + """ + view = FilterFieldsQuerysetView.as_view() + + # Tests that the decimal filter works. + search_decimal = Decimal('2.25') + request = factory.get('/?decimal=%s' % search_decimal) + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + expected_data = [f for f in self.data if f['decimal'] == search_decimal] + self.assertEqual(response.data, expected_data) + @unittest.skipUnless(django_filters, 'django-filters not installed') def test_get_filtered_class_root_view(self): """