From ce35de5707078d80cd6a9be74420752b2e23ebfe Mon Sep 17 00:00:00 2001 From: maksjood Date: Mon, 14 Oct 2024 23:01:52 +0330 Subject: [PATCH] Add more functionalities to OrderingFilter --- rest_framework/filters.py | 51 +++++++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 3f4730da8..65832b017 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -5,6 +5,7 @@ returned by list views. import operator import warnings from functools import reduce +from typing import Iterable from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.db import models @@ -217,7 +218,30 @@ class SearchFilter(BaseFilterBackend): ] +class OrderingExpressionFactory: + + def __init__(self, queryset_value: str = '', nulls_as_low: bool = True): + self.queryset_value = queryset_value + self.nulls_as_low = nulls_as_low + + def get_expression(self, given_value: str): + queryset_value = self.queryset_value or given_value + if given_value.startswith('-'): + return models.F(queryset_value.strip('-')).desc(nulls_last=self.nulls_as_low) + else: + return models.F(queryset_value).asc(nulls_first=self.nulls_as_low) + + class OrderingFilter(BaseFilterBackend): + ''' + If you use this class, the ordering fields can be a tuple of + (, ). Using `OrderingExpressionFactory`, + you can map query_param values to certain values in the queryset and whether `null` values + are considered as low or high. + + Using this class, you cas set `ordering_prefix (Iterable)` attribute to the view. This prefix would be + considered as the first ordering expressions. + ''' # The URL query parameter used for the ordering. ordering_param = api_settings.ORDERING_PARAM ordering_fields = None @@ -225,7 +249,7 @@ class OrderingFilter(BaseFilterBackend): ordering_description = _('Which field to use when ordering the results.') template = 'rest_framework/filters/ordering.html' - def get_ordering(self, request, queryset, view): + def get_ordering(self, request, queryset, view): # returns an iterable of expressions for ordering """ Ordering is set by a comma delimited ?ordering=... query parameter. @@ -233,16 +257,29 @@ class OrderingFilter(BaseFilterBackend): the `ordering_param` value on the OrderingFilter or by specifying an `ORDERING_PARAM` value in the API settings. """ - params = request.query_params.get(self.ordering_param) - if params: - fields = [param.strip() for param in params.split(',')] - ordering = self.remove_invalid_fields(queryset, fields, view, request) - if ordering: - return ordering + params = request.query_params.get(self.ordering_param, '') + if params or getattr(view, 'ordering_prefix', []): + params = [param.strip() for param in params.split(',')] + ordering_expressions = self.get_ordering_expressions(queryset, params, view, request) + if ordering_expressions: + return ordering_expressions # No ordering was included, or all the ordering fields were invalid return self.get_default_ordering(view) + def get_ordering_expressions(self, queryset, fields_in_query: Iterable[str], view, request): + valid_fields = dict(self.get_valid_fields(queryset, view, {'request': request})) + + ordering_expressions = list(getattr(view, 'ordering_prefix', [])) or [] + for field in fields_in_query: + exp = valid_fields.get(field, None) or valid_fields.get(field[1:], None) + if exp: + ordering_expressions.append( + exp.get_expression(field) if isinstance(exp, OrderingExpressionFactory) + else field + ) + return ordering_expressions + def get_default_ordering(self, view): ordering = getattr(view, 'ordering', None) if isinstance(ordering, str):