mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-10 19:56:59 +03:00
More robust default behavior on OrderingFilter (#4156)
This commit is contained in:
parent
dc09eef24a
commit
fe2aede18d
|
@ -222,24 +222,40 @@ class OrderingFilter(BaseFilterBackend):
|
|||
return (ordering,)
|
||||
return ordering
|
||||
|
||||
def get_default_valid_fields(self, queryset, view):
|
||||
# If `ordering_fields` is not specified, then we determine a default
|
||||
# based on the serializer class, if one exists on the view.
|
||||
if hasattr(view, 'get_serializer_class'):
|
||||
try:
|
||||
serializer_class = view.get_serializer_class()
|
||||
except AssertionError:
|
||||
# Raised by the default implementation if
|
||||
# no serializer_class was found
|
||||
serializer_class = None
|
||||
else:
|
||||
serializer_class = getattr(view, 'serializer_class', None)
|
||||
|
||||
if serializer_class is None:
|
||||
msg = (
|
||||
"Cannot use %s on a view which does not have either a "
|
||||
"'serializer_class', an overriding 'get_serializer_class' "
|
||||
"or 'ordering_fields' attribute."
|
||||
)
|
||||
raise ImproperlyConfigured(msg % self.__class__.__name__)
|
||||
|
||||
return [
|
||||
(field.source or field_name, field.label)
|
||||
for field_name, field in serializer_class().fields.items()
|
||||
if not getattr(field, 'write_only', False) and not field.source == '*'
|
||||
]
|
||||
|
||||
def get_valid_fields(self, queryset, view):
|
||||
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
||||
|
||||
if valid_fields is None:
|
||||
# Default to allowing filtering on serializer fields
|
||||
try:
|
||||
serializer_class = view.get_serializer_class()
|
||||
except AssertionError: # raised if no serializer_class was found
|
||||
msg = ("Cannot use %s on a view which does not have either a "
|
||||
"'serializer_class', an overriding 'get_serializer_class' "
|
||||
"or 'ordering_fields' attribute.")
|
||||
raise ImproperlyConfigured(msg % self.__class__.__name__)
|
||||
return self.get_default_valid_fields(queryset, view)
|
||||
|
||||
valid_fields = [
|
||||
(field.source or field_name, field.label)
|
||||
for field_name, field in serializer_class().fields.items()
|
||||
if not getattr(field, 'write_only', False) and not field.source == '*'
|
||||
]
|
||||
elif valid_fields == '__all__':
|
||||
# View explicitly allows filtering on any model field
|
||||
valid_fields = [
|
||||
|
|
Loading…
Reference in New Issue
Block a user