Cursor pagination now works with OrderingFilter

This commit is contained in:
Tom Christie 2015-01-22 16:12:05 +00:00
parent 408261ee02
commit 0822c9e558
3 changed files with 82 additions and 23 deletions

View File

@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend):
ordering_param = api_settings.ORDERING_PARAM ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None ordering_fields = None
def get_ordering(self, request): def get_ordering(self, request, queryset, view):
""" """
Ordering is set by a comma delimited ?ordering=... query parameter. Ordering is set by a comma delimited ?ordering=... query parameter.
@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend):
""" """
params = request.query_params.get(self.ordering_param) params = request.query_params.get(self.ordering_param)
if params: if params:
return [param.strip() for param in params.split(',')] fields = [param.strip() for param in params.split(',')]
ordering = self.remove_invalid_fields(queryset, fields, view)
if ordering:
return ordering
# No ordering was included, or all the ordering fields were invalid
return self.get_default_ordering(view)
def get_default_ordering(self, view): def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None) ordering = getattr(view, 'ordering', None)
@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend):
return (ordering,) return (ordering,)
return ordering return ordering
def remove_invalid_fields(self, queryset, ordering, view): def remove_invalid_fields(self, queryset, fields, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields) valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
if valid_fields is None: if valid_fields is None:
@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend):
valid_fields = [field.name for field in queryset.model._meta.fields] valid_fields = [field.name for field in queryset.model._meta.fields]
valid_fields += queryset.query.aggregates.keys() valid_fields += queryset.query.aggregates.keys()
return [term for term in ordering if term.lstrip('-') in valid_fields] return [term for term in fields if term.lstrip('-') in valid_fields]
def filter_queryset(self, request, queryset, view): def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request) ordering = self.get_ordering(request, queryset, view)
if ordering:
# Skip any incorrect parameters
ordering = self.remove_invalid_fields(queryset, ordering, view)
if not ordering:
# Use 'ordering' attribute by default
ordering = self.get_default_ordering(view)
if ordering: if ordering:
return queryset.order_by(*ordering) return queryset.order_by(*ordering)

View File

@ -427,8 +427,9 @@ class LimitOffsetPagination(BasePagination):
class CursorPagination(BasePagination): class CursorPagination(BasePagination):
# Support usage with OrderingFilter # Determine how/if True, False and None positions work - do the string
# Determine how/if True, False and None positions work # encodings work with Django queryset filters?
# Consider a max offset cap.
cursor_query_param = 'cursor' cursor_query_param = 'cursor'
page_size = api_settings.PAGINATE_BY page_size = api_settings.PAGINATE_BY
invalid_cursor_message = _('Invalid cursor') invalid_cursor_message = _('Invalid cursor')
@ -436,7 +437,7 @@ class CursorPagination(BasePagination):
def paginate_queryset(self, queryset, request, view=None): def paginate_queryset(self, queryset, request, view=None):
self.base_url = request.build_absolute_uri() self.base_url = request.build_absolute_uri()
self.ordering = self.get_ordering(view) self.ordering = self.get_ordering(request, queryset, view)
# Determine if we have a cursor, and if so then decode it. # Determine if we have a cursor, and if so then decode it.
encoded = request.query_params.get(self.cursor_query_param) encoded = request.query_params.get(self.cursor_query_param)
@ -600,16 +601,36 @@ class CursorPagination(BasePagination):
encoded = _encode_cursor(cursor) encoded = _encode_cursor(cursor)
return replace_query_param(self.base_url, self.cursor_query_param, encoded) return replace_query_param(self.base_url, self.cursor_query_param, encoded)
def get_ordering(self, view): def get_ordering(self, request, queryset, view):
""" """
Return a tuple of strings, that may be used in an `order_by` method. Return a tuple of strings, that may be used in an `order_by` method.
""" """
ordering = getattr(view, 'ordering', getattr(self, 'ordering', None)) ordering_filters = [
filter_cls for filter_cls in getattr(view, 'filter_backends', [])
if hasattr(filter_cls, 'get_ordering')
]
if ordering_filters:
# If a filter exists on the view that implements `get_ordering`
# then we defer to that filter to determine the ordering.
filter_cls = ordering_filters[0]
filter_instance = filter_cls()
ordering = filter_instance.get_ordering(request, queryset, view)
assert ordering is not None, (
'Using cursor pagination, but filter class {filter_cls} '
'returned a `None` ordering.'.format(
filter_cls=filter_cls.__name__
)
)
else:
# The default case is to check for an `ordering` attribute,
# first on the view instance, and then on this pagination instance.
ordering = getattr(view, 'ordering', getattr(self, 'ordering', None))
assert ordering is not None, ( assert ordering is not None, (
'Using cursor pagination, but no ordering attribute was declared ' 'Using cursor pagination, but no ordering attribute was declared '
'on the view or on the pagination class.' 'on the view or on the pagination class.'
) )
assert isinstance(ordering, (six.string_types, list, tuple)), ( assert isinstance(ordering, (six.string_types, list, tuple)), (
'Invalid ordering. Expected string or tuple, but got {type}'.format( 'Invalid ordering. Expected string or tuple, but got {type}'.format(
type=type(ordering).__name__ type=type(ordering).__name__
@ -618,7 +639,7 @@ class CursorPagination(BasePagination):
if isinstance(ordering, six.string_types): if isinstance(ordering, six.string_types):
return (ordering,) return (ordering,)
return ordering return tuple(ordering)
def _get_position_from_instance(self, instance, ordering): def _get_position_from_instance(self, instance, ordering):
attr = getattr(instance, ordering[0]) attr = getattr(instance, ordering[0])

View File

@ -77,6 +77,20 @@ class TestPaginationIntegration:
'count': 50 'count': 50
} }
def test_setting_page_size_to_zero(self):
"""
When page_size parameter is invalid it should return to the default.
"""
request = factory.get('/', {'page_size': 0})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
'results': [2, 4, 6, 8, 10],
'previous': None,
'next': 'http://testserver/?page=2&page_size=0',
'count': 50
}
def test_additional_query_params_are_preserved(self): def test_additional_query_params_are_preserved(self):
request = factory.get('/', {'page': 2, 'filter': 'even'}) request = factory.get('/', {'page': 2, 'filter': 'even'})
response = self.view(request) response = self.view(request)
@ -88,6 +102,14 @@ class TestPaginationIntegration:
'count': 50 'count': 50
} }
def test_404_not_found_for_zero_page(self):
request = factory.get('/', {'page': '0'})
response = self.view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {
'detail': 'Invalid page "0": That page number is less than 1.'
}
def test_404_not_found_for_invalid_page(self): def test_404_not_found_for_invalid_page(self):
request = factory.get('/', {'page': 'invalid'}) request = factory.get('/', {'page': 'invalid'})
response = self.view(request) response = self.view(request)
@ -507,6 +529,24 @@ class TestCursorPagination:
with pytest.raises(exceptions.NotFound): with pytest.raises(exceptions.NotFound):
self.pagination.paginate_queryset(self.queryset, request) self.pagination.paginate_queryset(self.queryset, request)
def test_use_with_ordering_filter(self):
class MockView:
filter_backends = (filters.OrderingFilter,)
ordering_fields = ['username', 'created']
ordering = 'created'
request = Request(factory.get('/', {'ordering': 'username'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('username',)
request = Request(factory.get('/', {'ordering': '-username'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('-username',)
request = Request(factory.get('/', {'ordering': 'invalid'}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('created',)
def test_cursor_pagination(self): def test_cursor_pagination(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/') (previous, current, next, previous_url, next_url) = self.get_pages('/')