From 228406e64a744fd2ae171d361a74116fcc7ed003 Mon Sep 17 00:00:00 2001 From: no-dap Date: Sat, 21 Sep 2019 15:31:03 +0900 Subject: [PATCH] Adding paginate decorator --- rest_framework/decorators.py | 29 +++++++++++++++++++++ tests/test_decorators.py | 49 +++++++++++++++++++++++++++++------- 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index eb1cad9e4..e1f2341e2 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -212,3 +212,32 @@ class MethodMapper(dict): def trace(self, func): return self._map('trace', func) + + +def paginate(pagination_class=None, **kwargs): + """ + Decorator that adds a pagination_class to GenericViewSet class. + Custom pagination class also available. + + Usage : + from rest_framework.pagination import CursorPagination + + @paginate(pagination_class=CursorPagination, page_size=5, ordering='-created_at') + class FooViewSet(viewsets.GenericViewSet): + ... + + """ + assert pagination_class is not None, ( + "@paginate missing required argument: 'pagination_class'" + ) + + class _Pagination(pagination_class): + def __init__(self): + self.__dict__.update(kwargs) + super(_Pagination, self).__init__() + + def decorator(_class): + _class.pagination_class = _Pagination + return _class + + return decorator diff --git a/tests/test_decorators.py b/tests/test_decorators.py index e10f0e5c5..af6f2af9c 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -1,12 +1,13 @@ import pytest from django.test import TestCase -from rest_framework import status +from rest_framework import pagination, status, viewsets from rest_framework.authentication import BasicAuthentication from rest_framework.decorators import ( - action, api_view, authentication_classes, parser_classes, + action, api_view, authentication_classes, paginate, parser_classes, permission_classes, renderer_classes, schema, throttle_classes ) +from rest_framework.pagination import BasePagination from rest_framework.parsers import JSONParser from rest_framework.permissions import IsAuthenticated from rest_framework.renderers import JSONRenderer @@ -49,7 +50,6 @@ class DecoratorTestCase(TestCase): return Response() def test_calling_method(self): - @api_view(['GET']) def view(request): return Response({}) @@ -63,7 +63,6 @@ class DecoratorTestCase(TestCase): assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED def test_calling_put_method(self): - @api_view(['GET', 'PUT']) def view(request): return Response({}) @@ -77,7 +76,6 @@ class DecoratorTestCase(TestCase): assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED def test_calling_patch_method(self): - @api_view(['GET', 'PATCH']) def view(request): return Response({}) @@ -91,7 +89,6 @@ class DecoratorTestCase(TestCase): assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED def test_renderer_classes(self): - @api_view(['GET']) @renderer_classes([JSONRenderer]) def view(request): @@ -102,7 +99,6 @@ class DecoratorTestCase(TestCase): assert isinstance(response.accepted_renderer, JSONRenderer) def test_parser_classes(self): - @api_view(['GET']) @parser_classes([JSONParser]) def view(request): @@ -114,7 +110,6 @@ class DecoratorTestCase(TestCase): view(request) def test_authentication_classes(self): - @api_view(['GET']) @authentication_classes([BasicAuthentication]) def view(request): @@ -126,7 +121,6 @@ class DecoratorTestCase(TestCase): view(request) def test_permission_classes(self): - @api_view(['GET']) @permission_classes([IsAuthenticated]) def view(request): @@ -156,6 +150,7 @@ class DecoratorTestCase(TestCase): """ Checks CustomSchema class is set on view """ + class CustomSchema(AutoSchema): pass @@ -213,6 +208,7 @@ class ActionDecoratorTestCase(TestCase): 'name' and 'suffix' are mutually exclusive kwargs used for generating a view's display name. """ + # by default, generate name from method @action(detail=True) def test_action(request): @@ -284,3 +280,38 @@ class ActionDecoratorTestCase(TestCase): @test_action.mapping.post def test_action(): raise NotImplementedError + + +class TestPaginateDecorator(TestCase): + + def test_empty_pagination_class(self): + msg = "@paginate missing required argument: 'pagination_class'" + with self.assertRaisesMessage(AssertionError, msg): + @paginate() + class MockGenericViewSet(viewsets.GenericViewSet): + pass + + def test_adding_page_number_pagination(self): + """ + Other default pagination classes' test result will be same as this even if kwargs changed to anything. + """ + + @paginate(pagination_class=pagination.PageNumberPagination, page_size=5, ordering='-created_at') + class MockGenericViewSet(viewsets.GenericViewSet): + pass + + assert hasattr(MockGenericViewSet, 'pagination_class') + assert MockGenericViewSet.pagination_class().page_size == 5 + assert MockGenericViewSet.pagination_class().ordering == '-created_at' + + def test_adding_custom_pagination(self): + class CustomPagination(BasePagination): + pass + + @paginate(pagination_class=CustomPagination, kwarg1='kwarg1', kwarg2='kwarg2') + class MockGenericViewSet(viewsets.GenericViewSet): + pass + + assert hasattr(MockGenericViewSet, 'pagination_class') + assert MockGenericViewSet.pagination_class().kwarg1 == 'kwarg1' + assert MockGenericViewSet.pagination_class().kwarg2 == 'kwarg2'