mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-23 15:54:16 +03:00
Allow filter model to be a subclass of the queryset one.
This commit is contained in:
parent
00b1877106
commit
4d45865bd7
|
@ -43,7 +43,7 @@ class DjangoFilterBackend(BaseFilterBackend):
|
|||
if filter_class:
|
||||
filter_model = filter_class.Meta.model
|
||||
|
||||
assert issubclass(filter_model, queryset.model), \
|
||||
assert issubclass(queryset.model, filter_model), \
|
||||
'FilterSet model %s does not match queryset model %s' % \
|
||||
(filter_model, queryset.model)
|
||||
|
||||
|
|
|
@ -60,6 +60,18 @@ class ReadOnlyManyToManyModel(RESTFrameworkModel):
|
|||
rel = models.ManyToManyField(Anchor)
|
||||
|
||||
|
||||
class BaseFilterableItem(RESTFrameworkModel):
|
||||
text = models.CharField(max_length=100)
|
||||
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
|
||||
class FilterableItem(BaseFilterableItem):
|
||||
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
||||
date = models.DateField()
|
||||
|
||||
|
||||
# Model for regression test for #285
|
||||
|
||||
class Comment(RESTFrameworkModel):
|
||||
|
|
|
@ -8,17 +8,12 @@ from django.utils import unittest
|
|||
from rest_framework import generics, serializers, status, filters
|
||||
from rest_framework.compat import django_filters, patterns, url
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.tests.models import BasicModel
|
||||
from rest_framework.tests.models import (BaseFilterableItem, BasicModel,
|
||||
FilterableItem)
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
||||
class FilterableItem(models.Model):
|
||||
text = models.CharField(max_length=100)
|
||||
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
||||
date = models.DateField()
|
||||
|
||||
|
||||
if django_filters:
|
||||
# Basic filter on a list view.
|
||||
class FilterFieldsRootView(generics.ListCreateAPIView):
|
||||
|
@ -59,6 +54,18 @@ if django_filters:
|
|||
filter_class = SeveralFieldsFilter
|
||||
filter_backends = (filters.DjangoFilterBackend,)
|
||||
|
||||
# These classes are used to test base model filter support
|
||||
class BaseFilterableItemFilter(django_filters.FilterSet):
|
||||
text = django_filters.CharFilter()
|
||||
|
||||
class Meta:
|
||||
model = BaseFilterableItem
|
||||
|
||||
class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
|
||||
model = FilterableItem
|
||||
filter_class = BaseFilterableItemFilter
|
||||
filter_backends = (filters.DjangoFilterBackend,)
|
||||
|
||||
# Regression test for #814
|
||||
class FilterableItemSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
|
@ -226,6 +233,18 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
|
|||
request = factory.get('/')
|
||||
self.assertRaises(AssertionError, view, request)
|
||||
|
||||
@unittest.skipUnless(django_filters, 'django-filter not installed')
|
||||
def test_base_model_filter(self):
|
||||
"""
|
||||
The `get_filter_class` model checks should allow base model filters.
|
||||
"""
|
||||
view = BaseFilterableItemFilterRootView.as_view()
|
||||
|
||||
request = factory.get('/?text=aaa')
|
||||
response = view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||
self.assertEqual(len(response.data), 1)
|
||||
|
||||
@unittest.skipUnless(django_filters, 'django-filter not installed')
|
||||
def test_unknown_filter(self):
|
||||
"""
|
||||
|
@ -612,4 +631,4 @@ class SensitiveOrderingFilterTests(TestCase):
|
|||
{'id': 2, username_field: 'userB'}, # PassC
|
||||
{'id': 3, username_field: 'userC'}, # PassA
|
||||
]
|
||||
)
|
||||
)
|
||||
|
|
|
@ -8,17 +8,11 @@ from django.utils import unittest
|
|||
from rest_framework import generics, status, pagination, filters, serializers
|
||||
from rest_framework.compat import django_filters
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.tests.models import BasicModel
|
||||
from rest_framework.tests.models import BasicModel, FilterableItem
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
||||
class FilterableItem(models.Model):
|
||||
text = models.CharField(max_length=100)
|
||||
decimal = models.DecimalField(max_digits=4, decimal_places=2)
|
||||
date = models.DateField()
|
||||
|
||||
|
||||
class RootView(generics.ListCreateAPIView):
|
||||
"""
|
||||
Example description for OPTIONS.
|
||||
|
|
Loading…
Reference in New Issue
Block a user