mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-24 10:34:03 +03:00
Add more functionalities to OrderingFilter
This commit is contained in:
parent
d3dd45b3f4
commit
ce35de5707
|
@ -5,6 +5,7 @@ returned by list views.
|
||||||
import operator
|
import operator
|
||||||
import warnings
|
import warnings
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
@ -217,7 +218,30 @@ class SearchFilter(BaseFilterBackend):
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class OrderingExpressionFactory:
|
||||||
|
|
||||||
|
def __init__(self, queryset_value: str = '', nulls_as_low: bool = True):
|
||||||
|
self.queryset_value = queryset_value
|
||||||
|
self.nulls_as_low = nulls_as_low
|
||||||
|
|
||||||
|
def get_expression(self, given_value: str):
|
||||||
|
queryset_value = self.queryset_value or given_value
|
||||||
|
if given_value.startswith('-'):
|
||||||
|
return models.F(queryset_value.strip('-')).desc(nulls_last=self.nulls_as_low)
|
||||||
|
else:
|
||||||
|
return models.F(queryset_value).asc(nulls_first=self.nulls_as_low)
|
||||||
|
|
||||||
|
|
||||||
class OrderingFilter(BaseFilterBackend):
|
class OrderingFilter(BaseFilterBackend):
|
||||||
|
'''
|
||||||
|
If you use this class, the ordering fields can be a tuple of
|
||||||
|
(<query_value>, <OrderingExpressionFactory instance>). Using `OrderingExpressionFactory`,
|
||||||
|
you can map query_param values to certain values in the queryset and whether `null` values
|
||||||
|
are considered as low or high.
|
||||||
|
|
||||||
|
Using this class, you cas set `ordering_prefix (Iterable)` attribute to the view. This prefix would be
|
||||||
|
considered as the first ordering expressions.
|
||||||
|
'''
|
||||||
# The URL query parameter used for the ordering.
|
# The URL query parameter used for the ordering.
|
||||||
ordering_param = api_settings.ORDERING_PARAM
|
ordering_param = api_settings.ORDERING_PARAM
|
||||||
ordering_fields = None
|
ordering_fields = None
|
||||||
|
@ -225,7 +249,7 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
ordering_description = _('Which field to use when ordering the results.')
|
ordering_description = _('Which field to use when ordering the results.')
|
||||||
template = 'rest_framework/filters/ordering.html'
|
template = 'rest_framework/filters/ordering.html'
|
||||||
|
|
||||||
def get_ordering(self, request, queryset, view):
|
def get_ordering(self, request, queryset, view): # returns an iterable of expressions for ordering
|
||||||
"""
|
"""
|
||||||
Ordering is set by a comma delimited ?ordering=... query parameter.
|
Ordering is set by a comma delimited ?ordering=... query parameter.
|
||||||
|
|
||||||
|
@ -233,16 +257,29 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
the `ordering_param` value on the OrderingFilter or by
|
the `ordering_param` value on the OrderingFilter or by
|
||||||
specifying an `ORDERING_PARAM` value in the API settings.
|
specifying an `ORDERING_PARAM` value in the API settings.
|
||||||
"""
|
"""
|
||||||
params = request.query_params.get(self.ordering_param)
|
params = request.query_params.get(self.ordering_param, '')
|
||||||
if params:
|
if params or getattr(view, 'ordering_prefix', []):
|
||||||
fields = [param.strip() for param in params.split(',')]
|
params = [param.strip() for param in params.split(',')]
|
||||||
ordering = self.remove_invalid_fields(queryset, fields, view, request)
|
ordering_expressions = self.get_ordering_expressions(queryset, params, view, request)
|
||||||
if ordering:
|
if ordering_expressions:
|
||||||
return ordering
|
return ordering_expressions
|
||||||
|
|
||||||
# No ordering was included, or all the ordering fields were invalid
|
# No ordering was included, or all the ordering fields were invalid
|
||||||
return self.get_default_ordering(view)
|
return self.get_default_ordering(view)
|
||||||
|
|
||||||
|
def get_ordering_expressions(self, queryset, fields_in_query: Iterable[str], view, request):
|
||||||
|
valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request}))
|
||||||
|
|
||||||
|
ordering_expressions = list(getattr(view, 'ordering_prefix', [])) or []
|
||||||
|
for field in fields_in_query:
|
||||||
|
exp = valid_fields.get(field, None) or valid_fields.get(field[1:], None)
|
||||||
|
if exp:
|
||||||
|
ordering_expressions.append(
|
||||||
|
exp.get_expression(field) if isinstance(exp, OrderingExpressionFactory)
|
||||||
|
else field
|
||||||
|
)
|
||||||
|
return ordering_expressions
|
||||||
|
|
||||||
def get_default_ordering(self, view):
|
def get_default_ordering(self, view):
|
||||||
ordering = getattr(view, 'ordering', None)
|
ordering = getattr(view, 'ordering', None)
|
||||||
if isinstance(ordering, str):
|
if isinstance(ordering, str):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user