mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-18 12:12:19 +03:00
use better solution and add test
This commit is contained in:
parent
226a37a761
commit
0561ecff33
|
@ -237,29 +237,13 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
params = request.query_params.get(self.ordering_param)
|
params = request.query_params.get(self.ordering_param)
|
||||||
if params:
|
if params:
|
||||||
fields = [param.strip() for param in params.split(',')]
|
fields = [param.strip() for param in params.split(',')]
|
||||||
valid_filed_names = self.remove_invalid_fields(queryset, fields, view, request)
|
ordering = self.remove_invalid_fields(queryset, fields, view, request)
|
||||||
ordering = self.convert_to_origin_filed_name(request, queryset, view, valid_filed_names)
|
|
||||||
if ordering:
|
if ordering:
|
||||||
return ordering
|
return ordering
|
||||||
|
|
||||||
# No ordering was included, or all the ordering fields were invalid
|
# No ordering was included, or all the ordering fields were invalid
|
||||||
return self.get_default_ordering(view)
|
return self.get_default_ordering(view)
|
||||||
|
|
||||||
def convert_to_origin_filed_name(self, request, queryset, view, ordering):
|
|
||||||
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
|
||||||
if valid_fields is None or valid_fields == '__all__':
|
|
||||||
return ordering
|
|
||||||
|
|
||||||
valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request}))
|
|
||||||
converted_fields = []
|
|
||||||
for field in ordering:
|
|
||||||
if field.startswith('-'):
|
|
||||||
converted_fields.append('-' + valid_fields[field[1:]])
|
|
||||||
else:
|
|
||||||
converted_fields.append(valid_fields[field])
|
|
||||||
|
|
||||||
return converted_fields
|
|
||||||
|
|
||||||
def get_default_ordering(self, view):
|
def get_default_ordering(self, view):
|
||||||
ordering = getattr(view, 'ordering', None)
|
ordering = getattr(view, 'ordering', None)
|
||||||
if isinstance(ordering, str):
|
if isinstance(ordering, str):
|
||||||
|
@ -328,14 +312,17 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
return valid_fields
|
return valid_fields
|
||||||
|
|
||||||
def remove_invalid_fields(self, queryset, fields, view, request):
|
def remove_invalid_fields(self, queryset, fields, view, request):
|
||||||
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
|
valid_fields = {item[1]: item[0] for item in self.get_valid_fields(queryset, view, {'request': request})}
|
||||||
|
|
||||||
def term_valid(term):
|
def term_valid(term):
|
||||||
if term.startswith("-"):
|
if term.startswith("-"):
|
||||||
term = term[1:]
|
term = term[1:]
|
||||||
return term in valid_fields
|
return valid_fields.get(term) is not None
|
||||||
|
|
||||||
return [term for term in fields if term_valid(term)]
|
return [
|
||||||
|
valid_fields.get(term) if not term.startswith("-") else '-' + valid_fields.get(term[1:])
|
||||||
|
for term in fields if term_valid(term)
|
||||||
|
]
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
ordering = self.get_ordering(request, queryset, view)
|
ordering = self.get_ordering(request, queryset, view)
|
||||||
|
|
|
@ -848,6 +848,30 @@ class OrderingFilterTests(TestCase):
|
||||||
with self.assertRaises(ImproperlyConfigured):
|
with self.assertRaises(ImproperlyConfigured):
|
||||||
view(request)
|
view(request)
|
||||||
|
|
||||||
|
def test_ordering_with_verbose_name(self):
|
||||||
|
for index, obj in enumerate(OrderingFilterModel.objects.all()):
|
||||||
|
OrderingFilterRelatedModel.objects.create(
|
||||||
|
related_object=obj,
|
||||||
|
index=index
|
||||||
|
)
|
||||||
|
|
||||||
|
class OrderingListView(generics.ListAPIView):
|
||||||
|
queryset = OrderingFilterModel.objects.all()
|
||||||
|
serializer_class = OrderingFilterSerializer
|
||||||
|
filter_backends = (filters.OrderingFilter,)
|
||||||
|
ordering = ('title',)
|
||||||
|
ordering_fields = (
|
||||||
|
('relateds__index', '-index'),
|
||||||
|
)
|
||||||
|
|
||||||
|
view = OrderingListView.as_view()
|
||||||
|
request = factory.get('/', {'ordering': 'index'})
|
||||||
|
response = view(request)
|
||||||
|
assert response.data == [
|
||||||
|
{'id': 3, 'title': 'xwv', 'text': 'cde'},
|
||||||
|
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
|
||||||
|
{'id': 1, 'title': 'zyx', 'text': 'abc'},
|
||||||
|
]
|
||||||
|
|
||||||
class SensitiveOrderingFilterModel(models.Model):
|
class SensitiveOrderingFilterModel(models.Model):
|
||||||
username = models.CharField(max_length=20)
|
username = models.CharField(max_length=20)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user