mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-02 20:54:42 +03:00
Test and fix for #814.
This commit is contained in:
parent
9d59e55cec
commit
de69a28b9e
|
@ -38,21 +38,27 @@ class DjangoFilterBackend(BaseFilterBackend):
|
||||||
"""
|
"""
|
||||||
filter_class = getattr(view, 'filter_class', None)
|
filter_class = getattr(view, 'filter_class', None)
|
||||||
filter_fields = getattr(view, 'filter_fields', None)
|
filter_fields = getattr(view, 'filter_fields', None)
|
||||||
view_model = getattr(view, 'model', None)
|
model_cls = getattr(view, 'model', None)
|
||||||
|
queryset = getattr(view, 'queryset', None)
|
||||||
|
if model_cls is None and queryset is not None:
|
||||||
|
model_cls = queryset.model
|
||||||
|
|
||||||
if filter_class:
|
if filter_class:
|
||||||
filter_model = filter_class.Meta.model
|
filter_model = filter_class.Meta.model
|
||||||
|
|
||||||
assert issubclass(filter_model, view_model), \
|
assert issubclass(filter_model, model_cls), \
|
||||||
'FilterSet model %s does not match view model %s' % \
|
'FilterSet model %s does not match view model %s' % \
|
||||||
(filter_model, view_model)
|
(filter_model, model_cls)
|
||||||
|
|
||||||
return filter_class
|
return filter_class
|
||||||
|
|
||||||
if filter_fields:
|
if filter_fields:
|
||||||
|
assert model_cls is not None, 'Cannot use DjangoFilterBackend ' \
|
||||||
|
'on a view which does not have a .model or .queryset attribute.'
|
||||||
|
|
||||||
class AutoFilterSet(self.default_filter_set):
|
class AutoFilterSet(self.default_filter_set):
|
||||||
class Meta:
|
class Meta:
|
||||||
model = view_model
|
model = model_cls
|
||||||
fields = filter_fields
|
fields = filter_fields
|
||||||
return AutoFilterSet
|
return AutoFilterSet
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from django.core.urlresolvers import reverse
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test.client import RequestFactory
|
from django.test.client import RequestFactory
|
||||||
from django.utils import unittest
|
from django.utils import unittest
|
||||||
from rest_framework import generics, status, filters
|
from rest_framework import generics, serializers, status, filters
|
||||||
from rest_framework.compat import django_filters, patterns, url
|
from rest_framework.compat import django_filters, patterns, url
|
||||||
from rest_framework.tests.models import FilterableItem, BasicModel
|
from rest_framework.tests.models import FilterableItem, BasicModel
|
||||||
|
|
||||||
|
@ -52,6 +52,17 @@ if django_filters:
|
||||||
filter_class = SeveralFieldsFilter
|
filter_class = SeveralFieldsFilter
|
||||||
filter_backend = filters.DjangoFilterBackend
|
filter_backend = filters.DjangoFilterBackend
|
||||||
|
|
||||||
|
# Regression test for #814
|
||||||
|
class FilterableItemSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = FilterableItem
|
||||||
|
|
||||||
|
class FilterFieldsQuerysetView(generics.ListCreateAPIView):
|
||||||
|
queryset = FilterableItem.objects.all()
|
||||||
|
serializer_class = FilterableItemSerializer
|
||||||
|
filter_fields = ['decimal', 'date']
|
||||||
|
filter_backend = filters.DjangoFilterBackend
|
||||||
|
|
||||||
urlpatterns = patterns('',
|
urlpatterns = patterns('',
|
||||||
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
|
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
|
||||||
url(r'^$', FilterClassRootView.as_view(), name='root-view'),
|
url(r'^$', FilterClassRootView.as_view(), name='root-view'),
|
||||||
|
@ -114,6 +125,21 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
|
||||||
expected_data = [f for f in self.data if f['date'] == search_date]
|
expected_data = [f for f in self.data if f['date'] == search_date]
|
||||||
self.assertEqual(response.data, expected_data)
|
self.assertEqual(response.data, expected_data)
|
||||||
|
|
||||||
|
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
||||||
|
def test_filter_with_queryset(self):
|
||||||
|
"""
|
||||||
|
Regression test for #814.
|
||||||
|
"""
|
||||||
|
view = FilterFieldsQuerysetView.as_view()
|
||||||
|
|
||||||
|
# Tests that the decimal filter works.
|
||||||
|
search_decimal = Decimal('2.25')
|
||||||
|
request = factory.get('/?decimal=%s' % search_decimal)
|
||||||
|
response = view(request).render()
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
expected_data = [f for f in self.data if f['decimal'] == search_decimal]
|
||||||
|
self.assertEqual(response.data, expected_data)
|
||||||
|
|
||||||
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
||||||
def test_get_filtered_class_root_view(self):
|
def test_get_filtered_class_root_view(self):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user