mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-18 04:02:35 +03:00
use better solution and add test
This commit is contained in:
parent
226a37a761
commit
0561ecff33
|
@ -237,29 +237,13 @@ class OrderingFilter(BaseFilterBackend):
|
|||
params = request.query_params.get(self.ordering_param)
|
||||
if params:
|
||||
fields = [param.strip() for param in params.split(',')]
|
||||
valid_filed_names = self.remove_invalid_fields(queryset, fields, view, request)
|
||||
ordering = self.convert_to_origin_filed_name(request, queryset, view, valid_filed_names)
|
||||
ordering = self.remove_invalid_fields(queryset, fields, view, request)
|
||||
if ordering:
|
||||
return ordering
|
||||
|
||||
# No ordering was included, or all the ordering fields were invalid
|
||||
return self.get_default_ordering(view)
|
||||
|
||||
def convert_to_origin_filed_name(self, request, queryset, view, ordering):
|
||||
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
||||
if valid_fields is None or valid_fields == '__all__':
|
||||
return ordering
|
||||
|
||||
valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request}))
|
||||
converted_fields = []
|
||||
for field in ordering:
|
||||
if field.startswith('-'):
|
||||
converted_fields.append('-' + valid_fields[field[1:]])
|
||||
else:
|
||||
converted_fields.append(valid_fields[field])
|
||||
|
||||
return converted_fields
|
||||
|
||||
def get_default_ordering(self, view):
|
||||
ordering = getattr(view, 'ordering', None)
|
||||
if isinstance(ordering, str):
|
||||
|
@ -328,14 +312,17 @@ class OrderingFilter(BaseFilterBackend):
|
|||
return valid_fields
|
||||
|
||||
def remove_invalid_fields(self, queryset, fields, view, request):
|
||||
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
|
||||
valid_fields = {item[1]: item[0] for item in self.get_valid_fields(queryset, view, {'request': request})}
|
||||
|
||||
def term_valid(term):
|
||||
if term.startswith("-"):
|
||||
term = term[1:]
|
||||
return term in valid_fields
|
||||
return valid_fields.get(term) is not None
|
||||
|
||||
return [term for term in fields if term_valid(term)]
|
||||
return [
|
||||
valid_fields.get(term) if not term.startswith("-") else '-' + valid_fields.get(term[1:])
|
||||
for term in fields if term_valid(term)
|
||||
]
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
ordering = self.get_ordering(request, queryset, view)
|
||||
|
|
|
@ -848,6 +848,30 @@ class OrderingFilterTests(TestCase):
|
|||
with self.assertRaises(ImproperlyConfigured):
|
||||
view(request)
|
||||
|
||||
def test_ordering_with_verbose_name(self):
|
||||
for index, obj in enumerate(OrderingFilterModel.objects.all()):
|
||||
OrderingFilterRelatedModel.objects.create(
|
||||
related_object=obj,
|
||||
index=index
|
||||
)
|
||||
|
||||
class OrderingListView(generics.ListAPIView):
|
||||
queryset = OrderingFilterModel.objects.all()
|
||||
serializer_class = OrderingFilterSerializer
|
||||
filter_backends = (filters.OrderingFilter,)
|
||||
ordering = ('title',)
|
||||
ordering_fields = (
|
||||
('relateds__index', '-index'),
|
||||
)
|
||||
|
||||
view = OrderingListView.as_view()
|
||||
request = factory.get('/', {'ordering': 'index'})
|
||||
response = view(request)
|
||||
assert response.data == [
|
||||
{'id': 3, 'title': 'xwv', 'text': 'cde'},
|
||||
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
|
||||
{'id': 1, 'title': 'zyx', 'text': 'abc'},
|
||||
]
|
||||
|
||||
class SensitiveOrderingFilterModel(models.Model):
|
||||
username = models.CharField(max_length=20)
|
||||
|
|
Loading…
Reference in New Issue
Block a user