diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 2a25378a0..a510d3bf6 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -227,11 +227,14 @@ class OrderingFilter(BaseFilterBackend): if valid_fields is None: # Default to allowing filtering on serializer fields - serializer_class = getattr(view, 'serializer_class') - if serializer_class is None: + try: + serializer_class = view.get_serializer_class() + except AssertionError: # raised if no serializer_class was found msg = ("Cannot use %s on a view which does not have either a " - "'serializer_class' or 'ordering_fields' attribute.") + "'serializer_class', an overriding 'get_serializer_class' " + "or 'ordering_fields' attribute.") raise ImproperlyConfigured(msg % self.__class__.__name__) + valid_fields = [ (field.source or field_name, field.label) for field_name, field in serializer_class().fields.items() diff --git a/tests/test_filters.py b/tests/test_filters.py index b72d95691..8493c96af 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -5,6 +5,7 @@ import unittest from decimal import Decimal from django.conf.urls import url +from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import reverse from django.db import models from django.test import TestCase @@ -754,6 +755,41 @@ class OrderingFilterTests(TestCase): self.assertContains(response, 'verbose title') + def test_ordering_with_overridden_get_serializer_class(self): + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + # note: no ordering_fields and serializer_class speficied + + def get_serializer_class(self): + return OrderingFilterSerializer + + view = OrderingListView.as_view() + request = factory.get('/', {'ordering': 'text'}) + response = view(request) + self.assertEqual( + response.data, + [ + {'id': 1, 'title': 'zyx', 'text': 'abc'}, + {'id': 2, 'title': 'yxw', 'text': 'bcd'}, + {'id': 3, 'title': 'xwv', 'text': 'cde'}, + ] + ) + + def test_ordering_with_improper_configuration(self): + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + filter_backends = (filters.OrderingFilter,) + ordering = ('title',) + # note: no ordering_fields and serializer_class + # or get_serializer_class speficied + + view = OrderingListView.as_view() + request = factory.get('/', {'ordering': 'text'}) + with self.assertRaises(ImproperlyConfigured): + view(request) + class SensitiveOrderingFilterModel(models.Model): username = models.CharField(max_length=20)