diff --git a/rest_framework/filters.py b/rest_framework/filters.py index de91caedc..f7ad37baa 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -43,7 +43,7 @@ class DjangoFilterBackend(BaseFilterBackend): if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, queryset.model), \ + assert issubclass(queryset.model, filter_model), \ 'FilterSet model %s does not match queryset model %s' % \ (filter_model, queryset.model) diff --git a/rest_framework/tests/models.py b/rest_framework/tests/models.py index 32a726c0b..0137d45a0 100644 --- a/rest_framework/tests/models.py +++ b/rest_framework/tests/models.py @@ -60,6 +60,18 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) +class BaseFilterableItem(RESTFrameworkModel): + text = models.CharField(max_length=100) + + class Meta: + abstract = True + + +class FilterableItem(BaseFilterableItem): + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + # Model for regression test for #285 class Comment(RESTFrameworkModel): diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index 181881865..769d34266 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -8,17 +8,12 @@ from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import BasicModel +from rest_framework.tests.models import (BaseFilterableItem, BasicModel, + FilterableItem) factory = APIRequestFactory() -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - if django_filters: # Basic filter on a list view. class FilterFieldsRootView(generics.ListCreateAPIView): @@ -59,6 +54,18 @@ if django_filters: filter_class = SeveralFieldsFilter filter_backends = (filters.DjangoFilterBackend,) + # These classes are used to test base model filter support + class BaseFilterableItemFilter(django_filters.FilterSet): + text = django_filters.CharFilter() + + class Meta: + model = BaseFilterableItem + + class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = BaseFilterableItemFilter + filter_backends = (filters.DjangoFilterBackend,) + # Regression test for #814 class FilterableItemSerializer(serializers.ModelSerializer): class Meta: @@ -226,6 +233,18 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/') self.assertRaises(AssertionError, view, request) + @unittest.skipUnless(django_filters, 'django-filter not installed') + def test_base_model_filter(self): + """ + The `get_filter_class` model checks should allow base model filters. + """ + view = BaseFilterableItemFilterRootView.as_view() + + request = factory.get('/?text=aaa') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + @unittest.skipUnless(django_filters, 'django-filter not installed') def test_unknown_filter(self): """ @@ -612,4 +631,4 @@ class SensitiveOrderingFilterTests(TestCase): {'id': 2, username_field: 'userB'}, # PassC {'id': 3, username_field: 'userC'}, # PassA ] - ) \ No newline at end of file + ) diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index cadb515fa..cd2996134 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -8,17 +8,11 @@ from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory -from rest_framework.tests.models import BasicModel +from rest_framework.tests.models import BasicModel, FilterableItem factory = APIRequestFactory() -class FilterableItem(models.Model): - text = models.CharField(max_length=100) - decimal = models.DecimalField(max_digits=4, decimal_places=2) - date = models.DateField() - - class RootView(generics.ListCreateAPIView): """ Example description for OPTIONS.