diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 100b31b71..87e154c75 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -568,7 +568,7 @@ class CursorPagination(BasePagination): def get_page_size(self, request): return self.page_size - def get_next_link(self): + def get_next_cursor(self): if not self.has_next: return None @@ -613,10 +613,18 @@ class CursorPagination(BasePagination): offset = self.cursor.offset + self.page_size position = self.previous_position - cursor = Cursor(offset=offset, reverse=False, position=position) - return self.encode_cursor(cursor) + return Cursor(offset=offset, reverse=False, position=position) - def get_previous_link(self): + def get_next_encoded_cursor(self): + return self.encode_cursor(self.get_next_cursor()) + + def get_next_link(self): + if not self.has_next: + return None + + return self.urify_cursor(self.get_next_encoded_cursor()) + + def get_previous_cursor(self): if not self.has_previous: return None @@ -661,8 +669,16 @@ class CursorPagination(BasePagination): offset = 0 position = self.next_position - cursor = Cursor(offset=offset, reverse=True, position=position) - return self.encode_cursor(cursor) + return Cursor(offset=offset, reverse=True, position=position) + + def get_previous_encoded_cursor(self): + return self.encode_cursor(self.get_previous_cursor()) + + def get_previous_link(self): + if not self.has_previous: + return None + + return self.urify_cursor(self.get_previous_encoded_cursor()) def get_ordering(self, request, queryset, view): """ @@ -738,6 +754,9 @@ class CursorPagination(BasePagination): """ Given a Cursor instance, return an url with encoded cursor. """ + if cursor is None: + return None + tokens = {} if cursor.offset != 0: tokens['o'] = str(cursor.offset) @@ -747,8 +766,10 @@ class CursorPagination(BasePagination): tokens['p'] = cursor.position querystring = urlparse.urlencode(tokens, doseq=True) - encoded = b64encode(querystring.encode('ascii')).decode('ascii') - return replace_query_param(self.base_url, self.cursor_query_param, encoded) + return b64encode(querystring.encode('ascii')).decode('ascii') + + def urify_cursor(self, encoded_cursor): + return replace_query_param(self.base_url, self.cursor_query_param, encoded_cursor) def _get_position_from_instance(self, instance, ordering): field_name = ordering[0].lstrip('-')