From b3b74b0dfe7a07fa65061dbe3aa6308a2696f3a6 Mon Sep 17 00:00:00 2001 From: Jeff Schwaber Date: Fri, 26 May 2017 12:35:25 -0700 Subject: [PATCH] pass the queryset through to the Pagination object so that it can be used by get_paginated_response() to provide a count of the entire queryset, even if at that point the only data available to respond with is a (page) subset of the results. --- rest_framework/generics.py | 4 ++-- rest_framework/mixins.py | 2 +- rest_framework/pagination.py | 6 +++--- tests/test_pagination.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 8d0bf284a..0ddc368e6 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -172,12 +172,12 @@ class GenericAPIView(views.APIView): return None return self.paginator.paginate_queryset(queryset, self.request, view=self) - def get_paginated_response(self, data): + def get_paginated_response(self, data, queryset): """ Return a paginated style `Response` object for the given output data. """ assert self.paginator is not None - return self.paginator.get_paginated_response(data) + return self.paginator.get_paginated_response(data, queryset) # Concrete view classes that provide method handlers diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index f3695e665..783dcc627 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -42,7 +42,7 @@ class ListModelMixin(object): page = self.paginate_queryset(queryset) if page is not None: serializer = self.get_serializer(page, many=True) - return self.get_paginated_response(serializer.data) + return self.get_paginated_response(serializer.data, queryset) serializer = self.get_serializer(queryset, many=True) return Response(serializer.data) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 0255cfc7f..4193ac232 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -225,7 +225,7 @@ class PageNumberPagination(BasePagination): self.request = request return list(self.page) - def get_paginated_response(self, data): + def get_paginated_response(self, data, queryset): return Response(OrderedDict([ ('count', self.page.paginator.count), ('next', self.get_next_link()), @@ -346,7 +346,7 @@ class LimitOffsetPagination(BasePagination): return [] return list(queryset[self.offset:self.offset + self.limit]) - def get_paginated_response(self, data): + def get_paginated_response(self, data, queryset): return Response(OrderedDict([ ('count', self.count), ('next', self.get_next_link()), @@ -758,7 +758,7 @@ class CursorPagination(BasePagination): attr = getattr(instance, field_name) return six.text_type(attr) - def get_paginated_response(self, data): + def get_paginated_response(self, data, queryset): return Response(OrderedDict([ ('next', self.get_next_link()), ('previous', self.get_previous_link()), diff --git a/tests/test_pagination.py b/tests/test_pagination.py index dd7f70330..39db63a4e 100644 --- a/tests/test_pagination.py +++ b/tests/test_pagination.py @@ -177,7 +177,7 @@ class TestPageNumberPagination: return list(self.pagination.paginate_queryset(self.queryset, request)) def get_paginated_content(self, queryset): - response = self.pagination.get_paginated_response(queryset) + response = self.pagination.get_paginated_response(queryset, queryset) return response.data def get_html_context(self): @@ -287,7 +287,7 @@ class TestPageNumberPaginationOverride: return list(self.pagination.paginate_queryset(self.queryset, request)) def get_paginated_content(self, queryset): - response = self.pagination.get_paginated_response(queryset) + response = self.pagination.get_paginated_response(queryset, queryset) return response.data def get_html_context(self): @@ -338,7 +338,7 @@ class TestLimitOffset: return list(self.pagination.paginate_queryset(self.queryset, request)) def get_paginated_content(self, queryset): - response = self.pagination.get_paginated_response(queryset) + response = self.pagination.get_paginated_response(queryset, queryset) return response.data def get_html_context(self):