diff --git a/docs/api-guide/filtering.md b/docs/api-guide/filtering.md index 07420d842..407530c69 100644 --- a/docs/api-guide/filtering.md +++ b/docs/api-guide/filtering.md @@ -294,7 +294,7 @@ It's recommended that you explicitly specify which fields the API should allowin This helps prevent unexpected data leakage, such as allowing users to order against a password hash field or other sensitive data. -If you *don't* specify an `ordering_fields` attribute on the view, the filter class will default to allowing the user to filter on any readable fields on the serializer specified by the `serializer_class` attribute. +If you *don't* specify an `ordering_fields` attribute on the view, the filter class will default to allowing the user to filter on any readable fields on the serializer specified by the `serializer_class` attribute, or the `get_serializer` or `get_serializer_class` method. If you are confident that the queryset being used by the view doesn't contain any sensitive data, you can also explicitly specify that a view should allow ordering on *any* model field or queryset aggregate, by using the special value `'__all__'`. diff --git a/rest_framework/filters.py b/rest_framework/filters.py index de91caedc..dc62b828f 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -124,19 +124,50 @@ class OrderingFilter(BaseFilterBackend): return (ordering,) return ordering - def remove_invalid_fields(self, queryset, ordering, view): + def get_serializer(self, request, view): + """ + Returns the serializer instance that should be used for fetching the + available fields. + """ + # Default to view's get_serializer method + get_serializer = getattr(view, 'get_serializer') + if get_serializer: + serializer = get_serializer() + + else: + # Try to get the serializer class from view's method + serializer_class = getattr(view, 'get_serializer_class') + if serializer_class is None: + # Try to get the serializer class from view's attribute + serializer_class = getattr(view, 'serializer_class') + if serializer_class is None: + msg = ("Cannot use %s on a view which does not have either" + " a 'serializer_class' or 'ordering_fields' " + "attribute, or a get_serializer or " + "get_serializer_class method.") + raise ImproperlyConfigured(msg % self.__class__.__name__) + + # Extra context provided to the serializer class + serializer_context = getattr(view, 'get_serializer_context') + if serializer_context is None: + serializer_context = { + 'request': request, + 'format': view.format_kwarg, + 'view': view, + } + serializer = serializer_class(context=serializer_context) + + return serializer + + def remove_invalid_fields(self, request, queryset, ordering, view): valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) if valid_fields is None: # Default to allowing filtering on serializer fields - serializer_class = getattr(view, 'serializer_class') - if serializer_class is None: - msg = ("Cannot use %s on a view which does not have either a " - "'serializer_class' or 'ordering_fields' attribute.") - raise ImproperlyConfigured(msg % self.__class__.__name__) + serializer = self.get_serializer(request, view) valid_fields = [ field.source or field_name - for field_name, field in serializer_class().fields.items() + for field_name, field in serializer.fields.items() if not getattr(field, 'write_only', False) ] elif valid_fields == '__all__': @@ -151,7 +182,7 @@ class OrderingFilter(BaseFilterBackend): if ordering: # Skip any incorrect parameters - ordering = self.remove_invalid_fields(queryset, ordering, view) + ordering = self.remove_invalid_fields(request, queryset, ordering, view) if not ordering: # Use 'ordering' attribute by default