diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 96d15eb9d..57a616c24 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -44,7 +44,7 @@ class DjangoFilterBackend(BaseFilterBackend): if filter_class: filter_model = filter_class.Meta.model - assert issubclass(filter_model, queryset.model), \ + assert issubclass(queryset.model, filter_model), \ 'FilterSet model %s does not match queryset model %s' % \ (filter_model, queryset.model) diff --git a/tests/models.py b/tests/models.py index fba3f8f7c..e378c1cfe 100644 --- a/tests/models.py +++ b/tests/models.py @@ -60,6 +60,18 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel): rel = models.ManyToManyField(Anchor) +class BaseFilterableItem(RESTFrameworkModel): + text = models.CharField(max_length=100) + + class Meta: + abstract = True + + +class FilterableItem(BaseFilterableItem): + decimal = models.DecimalField(max_digits=4, decimal_places=2) + date = models.DateField() + + # Model for regression test for #285 class Comment(RESTFrameworkModel): diff --git a/tests/test_filters.py b/tests/test_filters.py index 85840e018..38ddf4e43 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -9,7 +9,7 @@ from django.conf.urls import patterns, url from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters from rest_framework.test import APIRequestFactory -from .models import FilterableItem, BasicModel +from .models import BaseFilterableItem, FilterableItem, BasicModel from .utils import temporary_setting factory = APIRequestFactory() @@ -55,6 +55,18 @@ if django_filters: filter_class = SeveralFieldsFilter filter_backends = (filters.DjangoFilterBackend,) + # These classes are used to test base model filter support + class BaseFilterableItemFilter(django_filters.FilterSet): + text = django_filters.CharFilter() + + class Meta: + model = BaseFilterableItem + + class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): + model = FilterableItem + filter_class = BaseFilterableItemFilter + filter_backends = (filters.DjangoFilterBackend,) + # Regression test for #814 class FilterableItemSerializer(serializers.ModelSerializer): class Meta: @@ -225,6 +237,18 @@ class IntegrationTestFiltering(CommonFilteringTestCase): request = factory.get('/') self.assertRaises(AssertionError, view, request) + @unittest.skipUnless(django_filters, 'django-filter not installed') + def test_base_model_filter(self): + """ + The `get_filter_class` model checks should allow base model filters. + """ + view = BaseFilterableItemFilterRootView.as_view() + + request = factory.get('/?text=aaa') + response = view(request).render() + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 1) + @unittest.skipUnless(django_filters, 'django-filter not installed') def test_unknown_filter(self): """