diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 5938063af..2fd375500 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -503,15 +503,10 @@ class CursorPagination(BasePagination): self.base_url = request.build_absolute_uri() self.ordering = self.get_ordering(request, queryset, view) - # Determine if we have a cursor, and if so then decode it. - encoded = request.query_params.get(self.cursor_query_param) - if encoded is None: - self.cursor = None + self.cursor = self.decode_cursor(request) + if self.cursor is None: (offset, reverse, current_position) = (0, False, None) else: - self.cursor = _decode_cursor(encoded) - if self.cursor is None: - raise NotFound(self.invalid_cursor_message) (offset, reverse, current_position) = self.cursor # Cursor pagination always enforces an ordering. @@ -623,8 +618,7 @@ class CursorPagination(BasePagination): position = self.previous_position cursor = Cursor(offset=offset, reverse=False, position=position) - encoded = _encode_cursor(cursor) - return replace_query_param(self.base_url, self.cursor_query_param, encoded) + return self.encode_cursor(cursor) def get_previous_link(self): if not self.has_previous: @@ -672,8 +666,7 @@ class CursorPagination(BasePagination): position = self.next_position cursor = Cursor(offset=offset, reverse=True, position=position) - encoded = _encode_cursor(cursor) - return replace_query_param(self.base_url, self.cursor_query_param, encoded) + return self.encode_cursor(cursor) def get_ordering(self, request, queryset, view): """ @@ -715,6 +708,19 @@ class CursorPagination(BasePagination): return (ordering,) return tuple(ordering) + def decode_cursor(self, request): + # Determine if we have a cursor, and if so then decode it. + encoded = request.query_params.get(self.cursor_query_param) + if encoded is not None: + cursor = _decode_cursor(encoded) + if cursor is None: + raise NotFound(self.invalid_cursor_message) + return cursor + + def encode_cursor(self, cursor): + encoded = _encode_cursor(cursor) + return replace_query_param(self.base_url, self.cursor_query_param, encoded) + def _get_position_from_instance(self, instance, ordering): attr = getattr(instance, ordering[0].lstrip('-')) return six.text_type(attr)