Add more functionalities to OrderingFilter

This commit is contained in:
maksjood 2024-10-14 23:01:52 +03:30
parent d3dd45b3f4
commit ce35de5707

View File

@ -5,6 +5,7 @@ returned by list views.
import operator
import warnings
from functools import reduce
from typing import Iterable
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.db import models
@ -217,7 +218,30 @@ class SearchFilter(BaseFilterBackend):
]
class OrderingExpressionFactory:
def __init__(self, queryset_value: str = '', nulls_as_low: bool = True):
self.queryset_value = queryset_value
self.nulls_as_low = nulls_as_low
def get_expression(self, given_value: str):
queryset_value = self.queryset_value or given_value
if given_value.startswith('-'):
return models.F(queryset_value.strip('-')).desc(nulls_last=self.nulls_as_low)
else:
return models.F(queryset_value).asc(nulls_first=self.nulls_as_low)
class OrderingFilter(BaseFilterBackend):
'''
If you use this class, the ordering fields can be a tuple of
(<query_value>, <OrderingExpressionFactory instance>). Using `OrderingExpressionFactory`,
you can map query_param values to certain values in the queryset and whether `null` values
are considered as low or high.
Using this class, you cas set `ordering_prefix (Iterable)` attribute to the view. This prefix would be
considered as the first ordering expressions.
'''
# The URL query parameter used for the ordering.
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
@ -225,7 +249,7 @@ class OrderingFilter(BaseFilterBackend):
ordering_description = _('Which field to use when ordering the results.')
template = 'rest_framework/filters/ordering.html'
def get_ordering(self, request, queryset, view):
def get_ordering(self, request, queryset, view): # returns an iterable of expressions for ordering
"""
Ordering is set by a comma delimited ?ordering=... query parameter.
@ -233,16 +257,29 @@ class OrderingFilter(BaseFilterBackend):
the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings.
"""
params = request.query_params.get(self.ordering_param)
if params:
fields = [param.strip() for param in params.split(',')]
ordering = self.remove_invalid_fields(queryset, fields, view, request)
if ordering:
return ordering
params = request.query_params.get(self.ordering_param, '')
if params or getattr(view, 'ordering_prefix', []):
params = [param.strip() for param in params.split(',')]
ordering_expressions = self.get_ordering_expressions(queryset, params, view, request)
if ordering_expressions:
return ordering_expressions
# No ordering was included, or all the ordering fields were invalid
return self.get_default_ordering(view)
def get_ordering_expressions(self, queryset, fields_in_query: Iterable[str], view, request):
valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request}))
ordering_expressions = list(getattr(view, 'ordering_prefix', [])) or []
for field in fields_in_query:
exp = valid_fields.get(field, None) or valid_fields.get(field[1:], None)
if exp:
ordering_expressions.append(
exp.get_expression(field) if isinstance(exp, OrderingExpressionFactory)
else field
)
return ordering_expressions
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
if isinstance(ordering, str):