mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-29 13:04:03 +03:00
Cursor pagination now works with OrderingFilter
This commit is contained in:
parent
408261ee02
commit
0822c9e558
|
@ -114,7 +114,7 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
ordering_param = api_settings.ORDERING_PARAM
|
ordering_param = api_settings.ORDERING_PARAM
|
||||||
ordering_fields = None
|
ordering_fields = None
|
||||||
|
|
||||||
def get_ordering(self, request):
|
def get_ordering(self, request, queryset, view):
|
||||||
"""
|
"""
|
||||||
Ordering is set by a comma delimited ?ordering=... query parameter.
|
Ordering is set by a comma delimited ?ordering=... query parameter.
|
||||||
|
|
||||||
|
@ -124,7 +124,13 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
"""
|
"""
|
||||||
params = request.query_params.get(self.ordering_param)
|
params = request.query_params.get(self.ordering_param)
|
||||||
if params:
|
if params:
|
||||||
return [param.strip() for param in params.split(',')]
|
fields = [param.strip() for param in params.split(',')]
|
||||||
|
ordering = self.remove_invalid_fields(queryset, fields, view)
|
||||||
|
if ordering:
|
||||||
|
return ordering
|
||||||
|
|
||||||
|
# No ordering was included, or all the ordering fields were invalid
|
||||||
|
return self.get_default_ordering(view)
|
||||||
|
|
||||||
def get_default_ordering(self, view):
|
def get_default_ordering(self, view):
|
||||||
ordering = getattr(view, 'ordering', None)
|
ordering = getattr(view, 'ordering', None)
|
||||||
|
@ -132,7 +138,7 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
return (ordering,)
|
return (ordering,)
|
||||||
return ordering
|
return ordering
|
||||||
|
|
||||||
def remove_invalid_fields(self, queryset, ordering, view):
|
def remove_invalid_fields(self, queryset, fields, view):
|
||||||
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
|
||||||
|
|
||||||
if valid_fields is None:
|
if valid_fields is None:
|
||||||
|
@ -152,18 +158,10 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
valid_fields = [field.name for field in queryset.model._meta.fields]
|
valid_fields = [field.name for field in queryset.model._meta.fields]
|
||||||
valid_fields += queryset.query.aggregates.keys()
|
valid_fields += queryset.query.aggregates.keys()
|
||||||
|
|
||||||
return [term for term in ordering if term.lstrip('-') in valid_fields]
|
return [term for term in fields if term.lstrip('-') in valid_fields]
|
||||||
|
|
||||||
def filter_queryset(self, request, queryset, view):
|
def filter_queryset(self, request, queryset, view):
|
||||||
ordering = self.get_ordering(request)
|
ordering = self.get_ordering(request, queryset, view)
|
||||||
|
|
||||||
if ordering:
|
|
||||||
# Skip any incorrect parameters
|
|
||||||
ordering = self.remove_invalid_fields(queryset, ordering, view)
|
|
||||||
|
|
||||||
if not ordering:
|
|
||||||
# Use 'ordering' attribute by default
|
|
||||||
ordering = self.get_default_ordering(view)
|
|
||||||
|
|
||||||
if ordering:
|
if ordering:
|
||||||
return queryset.order_by(*ordering)
|
return queryset.order_by(*ordering)
|
||||||
|
|
|
@ -427,8 +427,9 @@ class LimitOffsetPagination(BasePagination):
|
||||||
|
|
||||||
|
|
||||||
class CursorPagination(BasePagination):
|
class CursorPagination(BasePagination):
|
||||||
# Support usage with OrderingFilter
|
# Determine how/if True, False and None positions work - do the string
|
||||||
# Determine how/if True, False and None positions work
|
# encodings work with Django queryset filters?
|
||||||
|
# Consider a max offset cap.
|
||||||
cursor_query_param = 'cursor'
|
cursor_query_param = 'cursor'
|
||||||
page_size = api_settings.PAGINATE_BY
|
page_size = api_settings.PAGINATE_BY
|
||||||
invalid_cursor_message = _('Invalid cursor')
|
invalid_cursor_message = _('Invalid cursor')
|
||||||
|
@ -436,7 +437,7 @@ class CursorPagination(BasePagination):
|
||||||
|
|
||||||
def paginate_queryset(self, queryset, request, view=None):
|
def paginate_queryset(self, queryset, request, view=None):
|
||||||
self.base_url = request.build_absolute_uri()
|
self.base_url = request.build_absolute_uri()
|
||||||
self.ordering = self.get_ordering(view)
|
self.ordering = self.get_ordering(request, queryset, view)
|
||||||
|
|
||||||
# Determine if we have a cursor, and if so then decode it.
|
# Determine if we have a cursor, and if so then decode it.
|
||||||
encoded = request.query_params.get(self.cursor_query_param)
|
encoded = request.query_params.get(self.cursor_query_param)
|
||||||
|
@ -600,16 +601,36 @@ class CursorPagination(BasePagination):
|
||||||
encoded = _encode_cursor(cursor)
|
encoded = _encode_cursor(cursor)
|
||||||
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
|
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
|
||||||
|
|
||||||
def get_ordering(self, view):
|
def get_ordering(self, request, queryset, view):
|
||||||
"""
|
"""
|
||||||
Return a tuple of strings, that may be used in an `order_by` method.
|
Return a tuple of strings, that may be used in an `order_by` method.
|
||||||
"""
|
"""
|
||||||
ordering = getattr(view, 'ordering', getattr(self, 'ordering', None))
|
ordering_filters = [
|
||||||
|
filter_cls for filter_cls in getattr(view, 'filter_backends', [])
|
||||||
|
if hasattr(filter_cls, 'get_ordering')
|
||||||
|
]
|
||||||
|
|
||||||
|
if ordering_filters:
|
||||||
|
# If a filter exists on the view that implements `get_ordering`
|
||||||
|
# then we defer to that filter to determine the ordering.
|
||||||
|
filter_cls = ordering_filters[0]
|
||||||
|
filter_instance = filter_cls()
|
||||||
|
ordering = filter_instance.get_ordering(request, queryset, view)
|
||||||
|
assert ordering is not None, (
|
||||||
|
'Using cursor pagination, but filter class {filter_cls} '
|
||||||
|
'returned a `None` ordering.'.format(
|
||||||
|
filter_cls=filter_cls.__name__
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# The default case is to check for an `ordering` attribute,
|
||||||
|
# first on the view instance, and then on this pagination instance.
|
||||||
|
ordering = getattr(view, 'ordering', getattr(self, 'ordering', None))
|
||||||
assert ordering is not None, (
|
assert ordering is not None, (
|
||||||
'Using cursor pagination, but no ordering attribute was declared '
|
'Using cursor pagination, but no ordering attribute was declared '
|
||||||
'on the view or on the pagination class.'
|
'on the view or on the pagination class.'
|
||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(ordering, (six.string_types, list, tuple)), (
|
assert isinstance(ordering, (six.string_types, list, tuple)), (
|
||||||
'Invalid ordering. Expected string or tuple, but got {type}'.format(
|
'Invalid ordering. Expected string or tuple, but got {type}'.format(
|
||||||
type=type(ordering).__name__
|
type=type(ordering).__name__
|
||||||
|
@ -618,7 +639,7 @@ class CursorPagination(BasePagination):
|
||||||
|
|
||||||
if isinstance(ordering, six.string_types):
|
if isinstance(ordering, six.string_types):
|
||||||
return (ordering,)
|
return (ordering,)
|
||||||
return ordering
|
return tuple(ordering)
|
||||||
|
|
||||||
def _get_position_from_instance(self, instance, ordering):
|
def _get_position_from_instance(self, instance, ordering):
|
||||||
attr = getattr(instance, ordering[0])
|
attr = getattr(instance, ordering[0])
|
||||||
|
|
|
@ -77,6 +77,20 @@ class TestPaginationIntegration:
|
||||||
'count': 50
|
'count': 50
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_setting_page_size_to_zero(self):
|
||||||
|
"""
|
||||||
|
When page_size parameter is invalid it should return to the default.
|
||||||
|
"""
|
||||||
|
request = factory.get('/', {'page_size': 0})
|
||||||
|
response = self.view(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {
|
||||||
|
'results': [2, 4, 6, 8, 10],
|
||||||
|
'previous': None,
|
||||||
|
'next': 'http://testserver/?page=2&page_size=0',
|
||||||
|
'count': 50
|
||||||
|
}
|
||||||
|
|
||||||
def test_additional_query_params_are_preserved(self):
|
def test_additional_query_params_are_preserved(self):
|
||||||
request = factory.get('/', {'page': 2, 'filter': 'even'})
|
request = factory.get('/', {'page': 2, 'filter': 'even'})
|
||||||
response = self.view(request)
|
response = self.view(request)
|
||||||
|
@ -88,6 +102,14 @@ class TestPaginationIntegration:
|
||||||
'count': 50
|
'count': 50
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_404_not_found_for_zero_page(self):
|
||||||
|
request = factory.get('/', {'page': '0'})
|
||||||
|
response = self.view(request)
|
||||||
|
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||||
|
assert response.data == {
|
||||||
|
'detail': 'Invalid page "0": That page number is less than 1.'
|
||||||
|
}
|
||||||
|
|
||||||
def test_404_not_found_for_invalid_page(self):
|
def test_404_not_found_for_invalid_page(self):
|
||||||
request = factory.get('/', {'page': 'invalid'})
|
request = factory.get('/', {'page': 'invalid'})
|
||||||
response = self.view(request)
|
response = self.view(request)
|
||||||
|
@ -507,6 +529,24 @@ class TestCursorPagination:
|
||||||
with pytest.raises(exceptions.NotFound):
|
with pytest.raises(exceptions.NotFound):
|
||||||
self.pagination.paginate_queryset(self.queryset, request)
|
self.pagination.paginate_queryset(self.queryset, request)
|
||||||
|
|
||||||
|
def test_use_with_ordering_filter(self):
|
||||||
|
class MockView:
|
||||||
|
filter_backends = (filters.OrderingFilter,)
|
||||||
|
ordering_fields = ['username', 'created']
|
||||||
|
ordering = 'created'
|
||||||
|
|
||||||
|
request = Request(factory.get('/', {'ordering': 'username'}))
|
||||||
|
ordering = self.pagination.get_ordering(request, [], MockView())
|
||||||
|
assert ordering == ('username',)
|
||||||
|
|
||||||
|
request = Request(factory.get('/', {'ordering': '-username'}))
|
||||||
|
ordering = self.pagination.get_ordering(request, [], MockView())
|
||||||
|
assert ordering == ('-username',)
|
||||||
|
|
||||||
|
request = Request(factory.get('/', {'ordering': 'invalid'}))
|
||||||
|
ordering = self.pagination.get_ordering(request, [], MockView())
|
||||||
|
assert ordering == ('created',)
|
||||||
|
|
||||||
def test_cursor_pagination(self):
|
def test_cursor_pagination(self):
|
||||||
(previous, current, next, previous_url, next_url) = self.get_pages('/')
|
(previous, current, next, previous_url, next_url) = self.get_pages('/')
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user