From 6fa589fefd48d98e4f0a11548b6c3e5ced58e31e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sun, 30 Sep 2012 17:31:28 +0100 Subject: [PATCH] Pagination support --- rest_framework/fields.py | 14 +++-- rest_framework/generics.py | 31 +++++++--- rest_framework/mixins.py | 21 ++++++- rest_framework/pagination.py | 34 +++++++++++ rest_framework/settings.py | 6 ++ rest_framework/templatetags/rest_framework.py | 2 +- rest_framework/tests/pagination.py | 57 +++++++++++++++++++ 7 files changed, 151 insertions(+), 14 deletions(-) create mode 100644 rest_framework/pagination.py create mode 100644 rest_framework/tests/pagination.py diff --git a/rest_framework/fields.py b/rest_framework/fields.py index eab90617c..74675ee9f 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -139,7 +139,13 @@ class Field(object): if hasattr(self, 'model_field'): return self.to_native(self.model_field._get_val_from_obj(obj)) - return self.to_native(getattr(obj, self.source or field_name)) + if self.source: + value = obj + for component in self.source.split('.'): + value = getattr(value, component) + else: + value = getattr(obj, field_name) + return self.to_native(value) def to_native(self, value): """ @@ -175,7 +181,7 @@ class RelatedField(Field): """ def field_to_native(self, obj, field_name): - obj = getattr(obj, field_name) + obj = getattr(obj, self.source or field_name) if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'): return [self.to_native(item) for item in obj.all()] return self.to_native(obj) @@ -215,10 +221,10 @@ class PrimaryKeyRelatedField(RelatedField): def field_to_native(self, obj, field_name): try: - obj = obj.serializable_value(field_name) + obj = obj.serializable_value(self.source or field_name) except AttributeError: field = obj._meta.get_field_by_name(field_name)[0] - obj = getattr(obj, field_name) + obj = getattr(obj, self.source or field_name) if obj.__class__.__name__ == 'RelatedManager': return [self.to_native(item.pk) for item in obj.all()] elif isinstance(field, RelatedObject): diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 4240e33e4..1e547b32e 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -2,7 +2,8 @@ Generic views that provide commmonly needed behaviour. """ -from rest_framework import views, mixins, serializers +from rest_framework import views, mixins +from rest_framework.settings import api_settings from django.views.generic.detail import SingleObjectMixin from django.views.generic.list import MultipleObjectMixin @@ -14,23 +15,37 @@ class BaseView(views.APIView): Base class for all other generic views. """ serializer_class = None + model_serializer_class = api_settings.MODEL_SERIALIZER + pagination_serializer_class = api_settings.PAGINATION_SERIALIZER + paginate_by = api_settings.PAGINATE_BY - def get_serializer(self, data=None, files=None, instance=None): + def get_serializer_context(self): + return { + 'request': self.request, + 'format': self.kwargs.get('format', None) + } + + def get_serializer(self, data=None, files=None, instance=None, kwargs=None): # TODO: add support for files # TODO: add support for seperate serializer/deserializer serializer_class = self.serializer_class + kwargs = kwargs or {} if serializer_class is None: - class DefaultSerializer(serializers.ModelSerializer): + class DefaultSerializer(self.model_serializer_class): class Meta: model = self.model serializer_class = DefaultSerializer - context = { - 'request': self.request, - 'format': self.kwargs.get('format', None) - } - return serializer_class(data, instance=instance, context=context) + context = self.get_serializer_context() + return serializer_class(data, instance=instance, context=context, **kwargs) + + def get_pagination_serializer(self, page=None): + serializer_class = self.pagination_serializer_class + context = self.get_serializer_context() + ret = serializer_class(instance=page, context=context) + ret.fields['results'] = self.get_serializer(kwargs={'source': 'object_list'}) + return ret class MultipleObjectBaseView(MultipleObjectMixin, BaseView): diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index fe12dc8ff..167cd89aa 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -7,6 +7,7 @@ which allows mixin classes to be composed in interesting ways. Eg. Use mixins to build a Resource class, and have a Router class perform the binding of http methods to actions for us. """ +from django.http import Http404 from rest_framework import status from rest_framework.response import Response @@ -30,9 +31,27 @@ class ListModelMixin(object): List a queryset. Should be mixed in with `MultipleObjectBaseView`. """ + empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." + def list(self, request, *args, **kwargs): self.object_list = self.get_queryset() - serializer = self.get_serializer(instance=self.object_list) + + # Default is to allow empty querysets. This can be altered by setting + # `.allow_empty = False`, to raise 404 errors on empty querysets. + allow_empty = self.get_allow_empty() + if not allow_empty and len(self.object_list) == 0: + error_args = {'class_name': self.__class__.__name__} + raise Http404(self.empty_error % error_args) + + # Pagination size is set by the `.paginate_by` attribute, + # which may be `None` to disable pagination. + page_size = self.get_paginate_by(self.object_list) + if page_size: + paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size) + serializer = self.get_pagination_serializer(page) + else: + serializer = self.get_serializer(instance=self.object_list) + return Response(serializer.data) diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py new file mode 100644 index 000000000..398e6f3d5 --- /dev/null +++ b/rest_framework/pagination.py @@ -0,0 +1,34 @@ +from rest_framework import serializers + +# TODO: Support URLconf kwarg-style paging + + +class NextPageField(serializers.Field): + def to_native(self, value): + if not value.has_next(): + return None + page = value.next_page_number() + request = self.context['request'] + return request.build_absolute_uri('?page=%d' % page) + + +class PreviousPageField(serializers.Field): + def to_native(self, value): + if not value.has_previous(): + return None + page = value.previous_page_number() + request = self.context['request'] + return request.build_absolute_uri('?page=%d' % page) + + +class PaginationSerializer(serializers.Serializer): + count = serializers.Field(source='paginator.count') + next = NextPageField(source='*') + previous = PreviousPageField(source='*') + + def to_native(self, obj): + """ + Prevent default behaviour of iterating over elements, and serializing + each in turn. + """ + return self.convert_object(obj) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index cfc89fe1a..8387fd294 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -44,6 +44,10 @@ DEFAULTS = { 'anon': None, }, + 'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer', + 'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer', + 'PAGINATE_BY': 20, + 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, @@ -65,6 +69,8 @@ IMPORT_STRINGS = ( 'DEFAULT_PERMISSIONS', 'DEFAULT_THROTTLES', 'DEFAULT_CONTENT_NEGOTIATION', + 'MODEL_SERIALIZER', + 'PAGINATION_SERIALIZER', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 377fd489b..c9b6eb10d 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -1,5 +1,5 @@ from django import template -from django.core.urlresolvers import reverse, NoReverseMatch +from django.core.urlresolvers import reverse from django.http import QueryDict from django.utils.encoding import force_unicode from django.utils.html import escape diff --git a/rest_framework/tests/pagination.py b/rest_framework/tests/pagination.py new file mode 100644 index 000000000..4ddfc9157 --- /dev/null +++ b/rest_framework/tests/pagination.py @@ -0,0 +1,57 @@ +from django.test import TestCase +from django.test.client import RequestFactory +from rest_framework import generics, status +from rest_framework.tests.models import BasicModel + +factory = RequestFactory() + + +class RootView(generics.RootAPIView): + """ + Example description for OPTIONS. + """ + model = BasicModel + paginate_by = 10 + + +class TestPaginatedView(TestCase): + def setUp(self): + """ + Create 26 BasicModel intances. + """ + for char in 'abcdefghijklmnopqrstuvwxyz': + BasicModel(text=char * 3).save() + self.objects = BasicModel.objects + self.data = [ + {'id': obj.id, 'text': obj.text} + for obj in self.objects.all() + ] + self.view = RootView.as_view() + + def test_get_paginated_root_view(self): + """ + GET requests to paginated RootAPIView should return paginated results. + """ + request = factory.get('/') + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 26) + self.assertEquals(response.data['results'], self.data[:10]) + self.assertNotEquals(response.data['next'], None) + self.assertEquals(response.data['previous'], None) + + request = factory.get(response.data['next']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 26) + self.assertEquals(response.data['results'], self.data[10:20]) + self.assertNotEquals(response.data['next'], None) + self.assertNotEquals(response.data['previous'], None) + + request = factory.get(response.data['next']) + response = self.view(request).render() + self.assertEquals(response.status_code, status.HTTP_200_OK) + self.assertEquals(response.data['count'], 26) + self.assertEquals(response.data['results'], self.data[20:]) + self.assertEquals(response.data['next'], None) + self.assertNotEquals(response.data['previous'], None)