From 0822c9e55820f8e4737329e38abc2e21718af9e5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 22 Jan 2015 16:12:05 +0000 Subject: [PATCH] Cursor pagination now works with OrderingFilter --- rest_framework/filters.py | 24 ++++++++++----------- rest_framework/pagination.py | 41 +++++++++++++++++++++++++++--------- tests/test_pagination.py | 40 +++++++++++++++++++++++++++++++++++ 3 files changed, 82 insertions(+), 23 deletions(-) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index d188a2d1e..2bcf36991 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend): ordering_param = api_settings.ORDERING_PARAM ordering_fields = None - def get_ordering(self, request): + def get_ordering(self, request, queryset, view): """ Ordering is set by a comma delimited ?ordering=... query parameter. @@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend): """ params = request.query_params.get(self.ordering_param) if params: - return [param.strip() for param in params.split(',')] + fields = [param.strip() for param in params.split(',')] + ordering = self.remove_invalid_fields(queryset, fields, view) + if ordering: + return ordering + + # No ordering was included, or all the ordering fields were invalid + return self.get_default_ordering(view) def get_default_ordering(self, view): ordering = getattr(view, 'ordering', None) @@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend): return (ordering,) return ordering - def remove_invalid_fields(self, queryset, ordering, view): + def remove_invalid_fields(self, queryset, fields, view): valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) if valid_fields is None: @@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend): valid_fields = [field.name for field in queryset.model._meta.fields] valid_fields += queryset.query.aggregates.keys() - return [term for term in ordering if term.lstrip('-') in valid_fields] + return [term for term in fields if term.lstrip('-') in valid_fields] def filter_queryset(self, request, queryset, view): - ordering = self.get_ordering(request) - - if ordering: - # Skip any incorrect parameters - ordering = self.remove_invalid_fields(queryset, ordering, view) - - if not ordering: - # Use 'ordering' attribute by default - ordering = self.get_default_ordering(view) + ordering = self.get_ordering(request, queryset, view) if ordering: return queryset.order_by(*ordering) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 7b28b47f0..1b4174bc6 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -427,8 +427,9 @@ class LimitOffsetPagination(BasePagination): class CursorPagination(BasePagination): - # Support usage with OrderingFilter - # Determine how/if True, False and None positions work + # Determine how/if True, False and None positions work - do the string + # encodings work with Django queryset filters? + # Consider a max offset cap. cursor_query_param = 'cursor' page_size = api_settings.PAGINATE_BY invalid_cursor_message = _('Invalid cursor') @@ -436,7 +437,7 @@ class CursorPagination(BasePagination): def paginate_queryset(self, queryset, request, view=None): self.base_url = request.build_absolute_uri() - self.ordering = self.get_ordering(view) + self.ordering = self.get_ordering(request, queryset, view) # Determine if we have a cursor, and if so then decode it. encoded = request.query_params.get(self.cursor_query_param) @@ -600,16 +601,36 @@ class CursorPagination(BasePagination): encoded = _encode_cursor(cursor) return replace_query_param(self.base_url, self.cursor_query_param, encoded) - def get_ordering(self, view): + def get_ordering(self, request, queryset, view): """ Return a tuple of strings, that may be used in an `order_by` method. """ - ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) + ordering_filters = [ + filter_cls for filter_cls in getattr(view, 'filter_backends', []) + if hasattr(filter_cls, 'get_ordering') + ] + + if ordering_filters: + # If a filter exists on the view that implements `get_ordering` + # then we defer to that filter to determine the ordering. + filter_cls = ordering_filters[0] + filter_instance = filter_cls() + ordering = filter_instance.get_ordering(request, queryset, view) + assert ordering is not None, ( + 'Using cursor pagination, but filter class {filter_cls} ' + 'returned a `None` ordering.'.format( + filter_cls=filter_cls.__name__ + ) + ) + else: + # The default case is to check for an `ordering` attribute, + # first on the view instance, and then on this pagination instance. + ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) + assert ordering is not None, ( + 'Using cursor pagination, but no ordering attribute was declared ' + 'on the view or on the pagination class.' + ) - assert ordering is not None, ( - 'Using cursor pagination, but no ordering attribute was declared ' - 'on the view or on the pagination class.' - ) assert isinstance(ordering, (six.string_types, list, tuple)), ( 'Invalid ordering. Expected string or tuple, but got {type}'.format( type=type(ordering).__name__ @@ -618,7 +639,7 @@ class CursorPagination(BasePagination): if isinstance(ordering, six.string_types): return (ordering,) - return ordering + return tuple(ordering) def _get_position_from_instance(self, instance, ordering): attr = getattr(instance, ordering[0]) diff --git a/tests/test_pagination.py b/tests/test_pagination.py index c05b4abab..338be610c 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -77,6 +77,20 @@ class TestPaginationIntegration: 'count': 50 } + def test_setting_page_size_to_zero(self): + """ + When page_size parameter is invalid it should return to the default. + """ + request = factory.get('/', {'page_size': 0}) + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == { + 'results': [2, 4, 6, 8, 10], + 'previous': None, + 'next': 'http://testserver/?page=2&page_size=0', + 'count': 50 + } + def test_additional_query_params_are_preserved(self): request = factory.get('/', {'page': 2, 'filter': 'even'}) response = self.view(request) @@ -88,6 +102,14 @@ class TestPaginationIntegration: 'count': 50 } + def test_404_not_found_for_zero_page(self): + request = factory.get('/', {'page': '0'}) + response = self.view(request) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response.data == { + 'detail': 'Invalid page "0": That page number is less than 1.' + } + def test_404_not_found_for_invalid_page(self): request = factory.get('/', {'page': 'invalid'}) response = self.view(request) @@ -507,6 +529,24 @@ class TestCursorPagination: with pytest.raises(exceptions.NotFound): self.pagination.paginate_queryset(self.queryset, request) + def test_use_with_ordering_filter(self): + class MockView: + filter_backends = (filters.OrderingFilter,) + ordering_fields = ['username', 'created'] + ordering = 'created' + + request = Request(factory.get('/', {'ordering': 'username'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('username',) + + request = Request(factory.get('/', {'ordering': '-username'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('-username',) + + request = Request(factory.get('/', {'ordering': 'invalid'})) + ordering = self.pagination.get_ordering(request, [], MockView()) + assert ordering == ('created',) + def test_cursor_pagination(self): (previous, current, next, previous_url, next_url) = self.get_pages('/')