Add more functionalities to OrderingFilter

This commit is contained in:
maksjood 2024-10-14 23:01:52 +03:30
parent d3dd45b3f4
commit ce35de5707

View File

@ -5,6 +5,7 @@ returned by list views.
import operator import operator
import warnings import warnings
from functools import reduce from functools import reduce
from typing import Iterable
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.db import models from django.db import models
@ -217,7 +218,30 @@ class SearchFilter(BaseFilterBackend):
] ]
class OrderingExpressionFactory:
def __init__(self, queryset_value: str = '', nulls_as_low: bool = True):
self.queryset_value = queryset_value
self.nulls_as_low = nulls_as_low
def get_expression(self, given_value: str):
queryset_value = self.queryset_value or given_value
if given_value.startswith('-'):
return models.F(queryset_value.strip('-')).desc(nulls_last=self.nulls_as_low)
else:
return models.F(queryset_value).asc(nulls_first=self.nulls_as_low)
class OrderingFilter(BaseFilterBackend): class OrderingFilter(BaseFilterBackend):
'''
If you use this class, the ordering fields can be a tuple of
(<query_value>, <OrderingExpressionFactory instance>). Using `OrderingExpressionFactory`,
you can map query_param values to certain values in the queryset and whether `null` values
are considered as low or high.
Using this class, you cas set `ordering_prefix (Iterable)` attribute to the view. This prefix would be
considered as the first ordering expressions.
'''
# The URL query parameter used for the ordering. # The URL query parameter used for the ordering.
ordering_param = api_settings.ORDERING_PARAM ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None ordering_fields = None
@ -225,7 +249,7 @@ class OrderingFilter(BaseFilterBackend):
ordering_description = _('Which field to use when ordering the results.') ordering_description = _('Which field to use when ordering the results.')
template = 'rest_framework/filters/ordering.html' template = 'rest_framework/filters/ordering.html'
def get_ordering(self, request, queryset, view): def get_ordering(self, request, queryset, view): # returns an iterable of expressions for ordering
""" """
Ordering is set by a comma delimited ?ordering=... query parameter. Ordering is set by a comma delimited ?ordering=... query parameter.
@ -233,16 +257,29 @@ class OrderingFilter(BaseFilterBackend):
the `ordering_param` value on the OrderingFilter or by the `ordering_param` value on the OrderingFilter or by
specifying an `ORDERING_PARAM` value in the API settings. specifying an `ORDERING_PARAM` value in the API settings.
""" """
params = request.query_params.get(self.ordering_param) params = request.query_params.get(self.ordering_param, '')
if params: if params or getattr(view, 'ordering_prefix', []):
fields = [param.strip() for param in params.split(',')] params = [param.strip() for param in params.split(',')]
ordering = self.remove_invalid_fields(queryset, fields, view, request) ordering_expressions = self.get_ordering_expressions(queryset, params, view, request)
if ordering: if ordering_expressions:
return ordering return ordering_expressions
# No ordering was included, or all the ordering fields were invalid # No ordering was included, or all the ordering fields were invalid
return self.get_default_ordering(view) return self.get_default_ordering(view)
def get_ordering_expressions(self, queryset, fields_in_query: Iterable[str], view, request):
valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request}))
ordering_expressions = list(getattr(view, 'ordering_prefix', [])) or []
for field in fields_in_query:
exp = valid_fields.get(field, None) or valid_fields.get(field[1:], None)
if exp:
ordering_expressions.append(
exp.get_expression(field) if isinstance(exp, OrderingExpressionFactory)
else field
)
return ordering_expressions
def get_default_ordering(self, view): def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None) ordering = getattr(view, 'ordering', None)
if isinstance(ordering, str): if isinstance(ordering, str):