This commit is contained in:
Benjamin Mampaey 2016-05-16 01:39:45 +00:00
commit c7de292901
2 changed files with 51 additions and 2 deletions

View File

@ -227,8 +227,9 @@ class OrderingFilter(BaseFilterBackend):
if valid_fields is None: if valid_fields is None:
# Default to allowing filtering on serializer fields # Default to allowing filtering on serializer fields
serializer_class = getattr(view, 'serializer_class') try:
if serializer_class is None: serializer_class = view.get_serializer_class()
except AssertionError:
msg = ("Cannot use %s on a view which does not have either a " msg = ("Cannot use %s on a view which does not have either a "
"'serializer_class' or 'ordering_fields' attribute.") "'serializer_class' or 'ordering_fields' attribute.")
raise ImproperlyConfigured(msg % self.__class__.__name__) raise ImproperlyConfigured(msg % self.__class__.__name__)

View File

@ -5,6 +5,7 @@ import unittest
from decimal import Decimal from decimal import Decimal
from django.conf.urls import url from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
@ -585,6 +586,53 @@ class OrderingFilterTests(TestCase):
) )
OrderingFilterModel(title=title, text=text).save() OrderingFilterModel(title=title, text=text).save()
def test_get_valid_fields_from_explicit_serializer_class(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter, )
ordering_fields = None
queryset = OrderingFilterModel.objects.all()
view = OrderingListView()
backend = filters.OrderingFilter()
valid_fields = backend.get_valid_fields(queryset, view)
expected_valid_fields = [(field.source or field_name, field.label)
for (field_name, field) in OrderingFilterSerializer().fields.items()
if not getattr(field, 'write_only', False) and not field.source == '*']
self.assertEqual(valid_fields, expected_valid_fields)
def test_get_valid_fields_from_explicit_get_serializer_class(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering_fields = None
def get_serializer_class(self):
return OrderingFilterSerializer
queryset = OrderingFilterModel.objects.all()
view = OrderingListView()
backend = filters.OrderingFilter()
valid_fields = backend.get_valid_fields(queryset, view)
expected_valid_fields = [(field.source or field_name, field.label)
for (field_name, field) in OrderingFilterSerializer().fields.items()
if not getattr(field, 'write_only', False) and not field.source == '*']
self.assertEqual(valid_fields, expected_valid_fields)
def test_improperly_configured_error_from_get_valid_fields(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering_fields = None
serializer_class = None
queryset = OrderingFilterModel.objects.all()
view = OrderingListView()
backend = filters.OrderingFilter()
with self.assertRaises(ImproperlyConfigured):
backend.get_valid_fields(queryset, view)
def test_ordering(self): def test_ordering(self):
class OrderingListView(generics.ListAPIView): class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all() queryset = OrderingFilterModel.objects.all()