From fe2aede18d33436fd9e2c8858cd6946450d2e82f Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 1 Jun 2016 11:08:04 +0100 Subject: [PATCH] More robust default behavior on OrderingFilter (#4156) --- rest_framework/filters.py | 40 +++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 532cb053a..3836e8170 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -222,24 +222,40 @@ class OrderingFilter(BaseFilterBackend): return (ordering,) return ordering + def get_default_valid_fields(self, queryset, view): + # If `ordering_fields` is not specified, then we determine a default + # based on the serializer class, if one exists on the view. + if hasattr(view, 'get_serializer_class'): + try: + serializer_class = view.get_serializer_class() + except AssertionError: + # Raised by the default implementation if + # no serializer_class was found + serializer_class = None + else: + serializer_class = getattr(view, 'serializer_class', None) + + if serializer_class is None: + msg = ( + "Cannot use %s on a view which does not have either a " + "'serializer_class', an overriding 'get_serializer_class' " + "or 'ordering_fields' attribute." + ) + raise ImproperlyConfigured(msg % self.__class__.__name__) + + return [ + (field.source or field_name, field.label) + for field_name, field in serializer_class().fields.items() + if not getattr(field, 'write_only', False) and not field.source == '*' + ] + def get_valid_fields(self, queryset, view): valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) if valid_fields is None: # Default to allowing filtering on serializer fields - try: - serializer_class = view.get_serializer_class() - except AssertionError: # raised if no serializer_class was found - msg = ("Cannot use %s on a view which does not have either a " - "'serializer_class', an overriding 'get_serializer_class' " - "or 'ordering_fields' attribute.") - raise ImproperlyConfigured(msg % self.__class__.__name__) + return self.get_default_valid_fields(queryset, view) - valid_fields = [ - (field.source or field_name, field.label) - for field_name, field in serializer_class().fields.items() - if not getattr(field, 'write_only', False) and not field.source == '*' - ] elif valid_fields == '__all__': # View explicitly allows filtering on any model field valid_fields = [