Pagination support

This commit is contained in:
Tom Christie 2012-09-30 17:31:28 +01:00
parent 43d3634e89
commit 6fa589fefd
7 changed files with 151 additions and 14 deletions

View File

@ -139,7 +139,13 @@ class Field(object):
if hasattr(self, 'model_field'):
return self.to_native(self.model_field._get_val_from_obj(obj))
return self.to_native(getattr(obj, self.source or field_name))
if self.source:
value = obj
for component in self.source.split('.'):
value = getattr(value, component)
else:
value = getattr(obj, field_name)
return self.to_native(value)
def to_native(self, value):
"""
@ -175,7 +181,7 @@ class RelatedField(Field):
"""
def field_to_native(self, obj, field_name):
obj = getattr(obj, field_name)
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
@ -215,10 +221,10 @@ class PrimaryKeyRelatedField(RelatedField):
def field_to_native(self, obj, field_name):
try:
obj = obj.serializable_value(field_name)
obj = obj.serializable_value(self.source or field_name)
except AttributeError:
field = obj._meta.get_field_by_name(field_name)[0]
obj = getattr(obj, field_name)
obj = getattr(obj, self.source or field_name)
if obj.__class__.__name__ == 'RelatedManager':
return [self.to_native(item.pk) for item in obj.all()]
elif isinstance(field, RelatedObject):

View File

@ -2,7 +2,8 @@
Generic views that provide commmonly needed behaviour.
"""
from rest_framework import views, mixins, serializers
from rest_framework import views, mixins
from rest_framework.settings import api_settings
from django.views.generic.detail import SingleObjectMixin
from django.views.generic.list import MultipleObjectMixin
@ -14,23 +15,37 @@ class BaseView(views.APIView):
Base class for all other generic views.
"""
serializer_class = None
model_serializer_class = api_settings.MODEL_SERIALIZER
pagination_serializer_class = api_settings.PAGINATION_SERIALIZER
paginate_by = api_settings.PAGINATE_BY
def get_serializer(self, data=None, files=None, instance=None):
def get_serializer_context(self):
return {
'request': self.request,
'format': self.kwargs.get('format', None)
}
def get_serializer(self, data=None, files=None, instance=None, kwargs=None):
# TODO: add support for files
# TODO: add support for seperate serializer/deserializer
serializer_class = self.serializer_class
kwargs = kwargs or {}
if serializer_class is None:
class DefaultSerializer(serializers.ModelSerializer):
class DefaultSerializer(self.model_serializer_class):
class Meta:
model = self.model
serializer_class = DefaultSerializer
context = {
'request': self.request,
'format': self.kwargs.get('format', None)
}
return serializer_class(data, instance=instance, context=context)
context = self.get_serializer_context()
return serializer_class(data, instance=instance, context=context, **kwargs)
def get_pagination_serializer(self, page=None):
serializer_class = self.pagination_serializer_class
context = self.get_serializer_context()
ret = serializer_class(instance=page, context=context)
ret.fields['results'] = self.get_serializer(kwargs={'source': 'object_list'})
return ret
class MultipleObjectBaseView(MultipleObjectMixin, BaseView):

View File

@ -7,6 +7,7 @@ which allows mixin classes to be composed in interesting ways.
Eg. Use mixins to build a Resource class, and have a Router class
perform the binding of http methods to actions for us.
"""
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
@ -30,9 +31,27 @@ class ListModelMixin(object):
List a queryset.
Should be mixed in with `MultipleObjectBaseView`.
"""
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs):
self.object_list = self.get_queryset()
serializer = self.get_serializer(instance=self.object_list)
# Default is to allow empty querysets. This can be altered by setting
# `.allow_empty = False`, to raise 404 errors on empty querysets.
allow_empty = self.get_allow_empty()
if not allow_empty and len(self.object_list) == 0:
error_args = {'class_name': self.__class__.__name__}
raise Http404(self.empty_error % error_args)
# Pagination size is set by the `.paginate_by` attribute,
# which may be `None` to disable pagination.
page_size = self.get_paginate_by(self.object_list)
if page_size:
paginator, page, queryset, is_paginated = self.paginate_queryset(self.object_list, page_size)
serializer = self.get_pagination_serializer(page)
else:
serializer = self.get_serializer(instance=self.object_list)
return Response(serializer.data)

View File

@ -0,0 +1,34 @@
from rest_framework import serializers
# TODO: Support URLconf kwarg-style paging
class NextPageField(serializers.Field):
def to_native(self, value):
if not value.has_next():
return None
page = value.next_page_number()
request = self.context['request']
return request.build_absolute_uri('?page=%d' % page)
class PreviousPageField(serializers.Field):
def to_native(self, value):
if not value.has_previous():
return None
page = value.previous_page_number()
request = self.context['request']
return request.build_absolute_uri('?page=%d' % page)
class PaginationSerializer(serializers.Serializer):
count = serializers.Field(source='paginator.count')
next = NextPageField(source='*')
previous = PreviousPageField(source='*')
def to_native(self, obj):
"""
Prevent default behaviour of iterating over elements, and serializing
each in turn.
"""
return self.convert_object(obj)

View File

@ -44,6 +44,10 @@ DEFAULTS = {
'anon': None,
},
'MODEL_SERIALIZER': 'rest_framework.serializers.ModelSerializer',
'PAGINATION_SERIALIZER': 'rest_framework.pagination.PaginationSerializer',
'PAGINATE_BY': 20,
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
@ -65,6 +69,8 @@ IMPORT_STRINGS = (
'DEFAULT_PERMISSIONS',
'DEFAULT_THROTTLES',
'DEFAULT_CONTENT_NEGOTIATION',
'MODEL_SERIALIZER',
'PAGINATION_SERIALIZER',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
)

View File

@ -1,5 +1,5 @@
from django import template
from django.core.urlresolvers import reverse, NoReverseMatch
from django.core.urlresolvers import reverse
from django.http import QueryDict
from django.utils.encoding import force_unicode
from django.utils.html import escape

View File

@ -0,0 +1,57 @@
from django.test import TestCase
from django.test.client import RequestFactory
from rest_framework import generics, status
from rest_framework.tests.models import BasicModel
factory = RequestFactory()
class RootView(generics.RootAPIView):
"""
Example description for OPTIONS.
"""
model = BasicModel
paginate_by = 10
class TestPaginatedView(TestCase):
def setUp(self):
"""
Create 26 BasicModel intances.
"""
for char in 'abcdefghijklmnopqrstuvwxyz':
BasicModel(text=char * 3).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.view = RootView.as_view()
def test_get_paginated_root_view(self):
"""
GET requests to paginated RootAPIView should return paginated results.
"""
request = factory.get('/')
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 26)
self.assertEquals(response.data['results'], self.data[:10])
self.assertNotEquals(response.data['next'], None)
self.assertEquals(response.data['previous'], None)
request = factory.get(response.data['next'])
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 26)
self.assertEquals(response.data['results'], self.data[10:20])
self.assertNotEquals(response.data['next'], None)
self.assertNotEquals(response.data['previous'], None)
request = factory.get(response.data['next'])
response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data['count'], 26)
self.assertEquals(response.data['results'], self.data[20:])
self.assertEquals(response.data['next'], None)
self.assertNotEquals(response.data['previous'], None)