From df1fc8df7d27ea3caee0ce3474a2fb5549ba9b45 Mon Sep 17 00:00:00 2001 From: Benjamin Mampaey Date: Thu, 25 Feb 2016 12:11:22 +0100 Subject: [PATCH] Fixed issue 3957 where serializer_class wass accessed directly on the view in OrderingFilter backend instead of calling the view's get_serializer_class method --- rest_framework/filters.py | 5 ++-- tests/test_filters.py | 48 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 42e77d910..8d287b1ce 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -227,8 +227,9 @@ class OrderingFilter(BaseFilterBackend): if valid_fields is None: # Default to allowing filtering on serializer fields - serializer_class = getattr(view, 'serializer_class') - if serializer_class is None: + try: + serializer_class = view.get_serializer_class() + except AssertionError: msg = ("Cannot use %s on a view which does not have either a " "'serializer_class' or 'ordering_fields' attribute.") raise ImproperlyConfigured(msg % self.__class__.__name__) diff --git a/tests/test_filters.py b/tests/test_filters.py index 729a7b75b..9809c9bfe 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -5,6 +5,7 @@ import unittest from decimal import Decimal from django.conf.urls import url +from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import reverse from django.db import models from django.test import TestCase @@ -585,6 +586,53 @@ class OrderingFilterTests(TestCase): ) OrderingFilterModel(title=title, text=text).save() + def test_get_valid_fields_from_explicit_serializer_class(self): + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + serializer_class = OrderingFilterSerializer + filter_backends = (filters.OrderingFilter, ) + ordering_fields = None + + queryset = OrderingFilterModel.objects.all() + view = OrderingListView() + backend = filters.OrderingFilter() + valid_fields = backend.get_valid_fields(queryset, view) + expected_valid_fields = [(field.source or field_name, field.label) + for (field_name, field) in OrderingFilterSerializer().fields.items() + if not getattr(field, 'write_only', False) and not field.source == '*'] + self.assertEqual(valid_fields, expected_valid_fields) + + def test_get_valid_fields_from_explicit_get_serializer_class(self): + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + filter_backends = (filters.OrderingFilter,) + ordering_fields = None + + def get_serializer_class(self): + return OrderingFilterSerializer + + queryset = OrderingFilterModel.objects.all() + view = OrderingListView() + backend = filters.OrderingFilter() + valid_fields = backend.get_valid_fields(queryset, view) + expected_valid_fields = [(field.source or field_name, field.label) + for (field_name, field) in OrderingFilterSerializer().fields.items() + if not getattr(field, 'write_only', False) and not field.source == '*'] + self.assertEqual(valid_fields, expected_valid_fields) + + def test_improperly_configured_error_from_get_valid_fields(self): + class OrderingListView(generics.ListAPIView): + queryset = OrderingFilterModel.objects.all() + filter_backends = (filters.OrderingFilter,) + ordering_fields = None + serializer_class = None + + queryset = OrderingFilterModel.objects.all() + view = OrderingListView() + backend = filters.OrderingFilter() + with self.assertRaises(ImproperlyConfigured): + backend.get_valid_fields(queryset, view) + def test_ordering(self): class OrderingListView(generics.ListAPIView): queryset = OrderingFilterModel.objects.all()