django-rest-framework/djangorestframework/tests/mixins.py

240 lines
8.6 KiB
Python
Raw Normal View History

2011-12-09 17:37:53 +04:00
"""Tests for the mixin module"""
from django.test import TestCase
2011-12-09 17:37:53 +04:00
from django.utils import simplejson as json
from djangorestframework import status
from djangorestframework.compat import RequestFactory
from django.contrib.auth.models import Group, User
from djangorestframework.mixins import PaginatorMixin, ModelMixin
from djangorestframework.resources import ModelResource
2011-12-09 17:37:53 +04:00
from djangorestframework.response import Response
from djangorestframework.tests.models import CustomUser
from djangorestframework.tests.testcases import TestModelsTestCase
2011-12-09 17:37:53 +04:00
from djangorestframework.views import View
class TestModelCreation(TestModelsTestCase):
"""Tests on CreateModelMixin"""
def setUp(self):
super(TestModelsTestCase, self).setUp()
self.req = RequestFactory()
def test_creation(self):
self.assertEquals(0, Group.objects.count())
class GroupResource(ModelResource):
model = Group
form_data = {'name': 'foo'}
request = self.req.post('/groups', data=form_data)
mixin = ModelMixin()
mixin.resource = GroupResource
mixin.CONTENT = form_data
2011-12-09 17:37:53 +04:00
response = mixin.create(request)
self.assertEquals(1, Group.objects.count())
self.assertEquals('foo', response.cleaned_content.name)
def test_creation_with_m2m_relation(self):
class UserResource(ModelResource):
model = User
2011-12-09 17:37:53 +04:00
def url(self, instance):
return "/users/%i" % instance.id
group = Group(name='foo')
group.save()
2011-12-09 17:37:53 +04:00
form_data = {
'username': 'bar',
'password': 'baz',
'groups': [group.id]
}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(1, User.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo', response.cleaned_content.groups.all()[0].name)
2011-12-09 17:37:53 +04:00
def test_creation_with_m2m_relation_through(self):
"""
Tests creation where the m2m relation uses a through table
"""
class UserResource(ModelResource):
model = CustomUser
2011-12-09 17:37:53 +04:00
def url(self, instance):
return "/customusers/%i" % instance.id
2011-12-09 17:37:53 +04:00
form_data = {'username': 'bar0', 'groups': []}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = []
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(1, CustomUser.objects.count())
2011-12-09 17:37:53 +04:00
self.assertEquals(0, response.cleaned_content.groups.count())
group = Group(name='foo1')
group.save()
2011-12-09 17:37:53 +04:00
form_data = {'username': 'bar1', 'groups': [group.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(2, CustomUser.objects.count())
self.assertEquals(1, response.cleaned_content.groups.count())
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
2011-12-09 17:37:53 +04:00
group2 = Group(name='foo2')
2011-12-09 17:37:53 +04:00
group2.save()
form_data = {'username': 'bar2', 'groups': [group.id, group2.id]}
request = self.req.post('/groups', data=form_data)
cleaned_data = dict(form_data)
cleaned_data['groups'] = [group, group2]
mixin = ModelMixin()
mixin.resource = UserResource
mixin.CONTENT = cleaned_data
response = mixin.create(request)
self.assertEquals(3, CustomUser.objects.count())
self.assertEquals(2, response.cleaned_content.groups.count())
2011-07-11 23:52:02 +04:00
self.assertEquals('foo1', response.cleaned_content.groups.all()[0].name)
self.assertEquals('foo2', response.cleaned_content.groups.all()[1].name)
2011-12-09 17:37:53 +04:00
class MockPaginatorView(PaginatorMixin, View):
total = 60
def get(self, request):
return range(0, self.total)
def post(self, request):
return Response(status.CREATED, {'status': 'OK'})
class TestPagination(TestCase):
def setUp(self):
self.req = RequestFactory()
def test_default_limit(self):
""" Tests if pagination works without overwriting the limit """
request = self.req.get('/paginator')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertEqual(MockPaginatorView.limit, content['per_page'])
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
def test_overwriting_limit(self):
""" Tests if the limit can be overwritten """
limit = 10
request = self.req.get('/paginator')
response = MockPaginatorView.as_view(limit=limit)(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(content['per_page'], limit)
self.assertEqual(range(0, limit), content['results'])
def test_limit_param(self):
""" Tests if the client can set the limit """
from math import ceil
limit = 5
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
request = self.req.get('/paginator/?limit=%d' % limit)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertEqual(limit, content['per_page'])
self.assertEqual(num_pages, content['pages'])
def test_exceeding_limit(self):
""" Makes sure the client cannot exceed the default limit """
from math import ceil
limit = MockPaginatorView.limit + 10
num_pages = int(ceil(MockPaginatorView.total / float(limit)))
request = self.req.get('/paginator/?limit=%d' % limit)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(MockPaginatorView.total, content['total'])
self.assertNotEqual(limit, content['per_page'])
self.assertNotEqual(num_pages, content['pages'])
self.assertEqual(MockPaginatorView.limit, content['per_page'])
def test_only_works_for_get(self):
""" Pagination should only work for GET requests """
request = self.req.post('/paginator', data={'content': 'spam'})
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.CREATED)
self.assertEqual(None, content.get('per_page'))
self.assertEqual('OK', content['status'])
def test_non_int_page(self):
""" Tests that it can handle invalid values """
request = self.req.get('/paginator/?page=spam')
response = MockPaginatorView.as_view()(request)
self.assertEqual(response.status_code, status.NOT_FOUND)
def test_page_range(self):
""" Tests that the page range is handle correctly """
request = self.req.get('/paginator/?page=0')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.NOT_FOUND)
request = self.req.get('/paginator/')
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(range(0, MockPaginatorView.limit), content['results'])
num_pages = content['pages']
request = self.req.get('/paginator/?page=%d' % num_pages)
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.OK)
self.assertEqual(range(MockPaginatorView.limit*(num_pages-1), MockPaginatorView.total), content['results'])
request = self.req.get('/paginator/?page=%d' % (num_pages + 1,))
response = MockPaginatorView.as_view()(request)
content = json.loads(response.content)
self.assertEqual(response.status_code, status.NOT_FOUND)