Allow filter model to be a subclass of the queryset one.

This commit is contained in:
Simon Charette 2014-02-09 00:50:03 -05:00
parent 00b1877106
commit 4d45865bd7
4 changed files with 41 additions and 16 deletions

View File

@ -43,7 +43,7 @@ class DjangoFilterBackend(BaseFilterBackend):
if filter_class:
filter_model = filter_class.Meta.model
assert issubclass(filter_model, queryset.model), \
assert issubclass(queryset.model, filter_model), \
'FilterSet model %s does not match queryset model %s' % \
(filter_model, queryset.model)

View File

@ -60,6 +60,18 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
rel = models.ManyToManyField(Anchor)
class BaseFilterableItem(RESTFrameworkModel):
text = models.CharField(max_length=100)
class Meta:
abstract = True
class FilterableItem(BaseFilterableItem):
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
# Model for regression test for #285
class Comment(RESTFrameworkModel):

View File

@ -8,17 +8,12 @@ from django.utils import unittest
from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
from rest_framework.tests.models import (BaseFilterableItem, BasicModel,
FilterableItem)
factory = APIRequestFactory()
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
if django_filters:
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
@ -59,6 +54,18 @@ if django_filters:
filter_class = SeveralFieldsFilter
filter_backends = (filters.DjangoFilterBackend,)
# These classes are used to test base model filter support
class BaseFilterableItemFilter(django_filters.FilterSet):
text = django_filters.CharFilter()
class Meta:
model = BaseFilterableItem
class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
model = FilterableItem
filter_class = BaseFilterableItemFilter
filter_backends = (filters.DjangoFilterBackend,)
# Regression test for #814
class FilterableItemSerializer(serializers.ModelSerializer):
class Meta:
@ -226,6 +233,18 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
request = factory.get('/')
self.assertRaises(AssertionError, view, request)
@unittest.skipUnless(django_filters, 'django-filter not installed')
def test_base_model_filter(self):
"""
The `get_filter_class` model checks should allow base model filters.
"""
view = BaseFilterableItemFilterRootView.as_view()
request = factory.get('/?text=aaa')
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 1)
@unittest.skipUnless(django_filters, 'django-filter not installed')
def test_unknown_filter(self):
"""
@ -612,4 +631,4 @@ class SensitiveOrderingFilterTests(TestCase):
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
]
)
)

View File

@ -8,17 +8,11 @@ from django.utils import unittest
from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel
from rest_framework.tests.models import BasicModel, FilterableItem
factory = APIRequestFactory()
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.