use better solution and add test

This commit is contained in:
ali 2023-08-05 02:24:10 +03:30
parent 226a37a761
commit 0561ecff33
2 changed files with 31 additions and 20 deletions

View File

@ -237,29 +237,13 @@ class OrderingFilter(BaseFilterBackend):
params = request.query_params.get(self.ordering_param)
if params:
fields = [param.strip() for param in params.split(',')]
valid_filed_names = self.remove_invalid_fields(queryset, fields, view, request)
ordering = self.convert_to_origin_filed_name(request, queryset, view, valid_filed_names)
ordering = self.remove_invalid_fields(queryset, fields, view, request)
if ordering:
return ordering
# No ordering was included, or all the ordering fields were invalid
return self.get_default_ordering(view)
def convert_to_origin_filed_name(self, request, queryset, view, ordering):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None or valid_fields == '__all__':
return ordering
valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request}))
converted_fields = []
for field in ordering:
if field.startswith('-'):
converted_fields.append('-' + valid_fields[field[1:]])
else:
converted_fields.append(valid_fields[field])
return converted_fields
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
if isinstance(ordering, str):
@ -328,14 +312,17 @@ class OrderingFilter(BaseFilterBackend):
return valid_fields
def remove_invalid_fields(self, queryset, fields, view, request):
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
valid_fields = {item[1]: item[0] for item in self.get_valid_fields(queryset, view, {'request': request})}
def term_valid(term):
if term.startswith("-"):
term = term[1:]
return term in valid_fields
return valid_fields.get(term) is not None
return [term for term in fields if term_valid(term)]
return [
valid_fields.get(term) if not term.startswith("-") else '-' + valid_fields.get(term[1:])
for term in fields if term_valid(term)
]
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request, queryset, view)

View File

@ -848,6 +848,30 @@ class OrderingFilterTests(TestCase):
with self.assertRaises(ImproperlyConfigured):
view(request)
def test_ordering_with_verbose_name(self):
for index, obj in enumerate(OrderingFilterModel.objects.all()):
OrderingFilterRelatedModel.objects.create(
related_object=obj,
index=index
)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = (
('relateds__index', '-index'),
)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'index'})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
]
class SensitiveOrderingFilterModel(models.Model):
username = models.CharField(max_length=20)