Merge pull request #1375 from linovia/feature/django_1_7

Django 1.7 compatibility
This commit is contained in:
Tom Christie 2014-03-03 14:32:38 +00:00
commit 24a6882232
12 changed files with 134 additions and 69 deletions

View File

@ -7,6 +7,7 @@ python:
- "3.3" - "3.3"
env: env:
- DJANGO="https://www.djangoproject.com/download/1.7a2/tarball/"
- DJANGO="django==1.6.2" - DJANGO="django==1.6.2"
- DJANGO="django==1.5.5" - DJANGO="django==1.5.5"
- DJANGO="django==1.4.10" - DJANGO="django==1.4.10"
@ -14,13 +15,15 @@ env:
install: install:
- pip install $DJANGO - pip install $DJANGO
- pip install defusedxml==0.3 - pip install defusedxml==0.3 Pillow
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.1; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.4; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-guardian==1.1.1; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-guardian==1.1.1; fi"
- "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi" - "if [[ ${DJANGO::11} == 'django==1.3' ]]; then pip install django-filter==0.5.4; fi"
- "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.6; fi" - "if [[ ${DJANGO::11} != 'django==1.3' ]]; then pip install django-filter==0.7; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} == '3' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- "if [[ ${DJANGO} == 'https://www.djangoproject.com/download/1.7a2/tarball/' ]]; then pip install -e git+https://github.com/linovia/django-guardian.git@feature/django_1_7#egg=django-guardian-1.2.0; fi"
- export PYTHONPATH=. - export PYTHONPATH=.
script: script:
@ -28,6 +31,8 @@ script:
matrix: matrix:
exclude: exclude:
- python: "2.6"
env: DJANGO="https://www.djangoproject.com/download/1.7a2/tarball/"
- python: "3.2" - python: "3.2"
env: DJANGO="django==1.4.10" env: DJANGO="django==1.4.10"
- python: "3.2" - python: "3.2"

View File

@ -584,3 +584,23 @@ if six.PY3:
else: else:
def is_non_str_iterable(obj): def is_non_str_iterable(obj):
return hasattr(obj, '__iter__') return hasattr(obj, '__iter__')
try:
from django.utils.encoding import python_2_unicode_compatible
except ImportError:
def python_2_unicode_compatible(klass):
"""
A decorator that defines __unicode__ and __str__ methods under Python 2.
Under Python 3 it does nothing.
To support Python 2 and 3 with a single code base, define a __str__ method
returning text and apply this decorator to the class.
"""
if '__str__' not in klass.__dict__:
raise ValueError("@python_2_unicode_compatible cannot be applied "
"to %s because it doesn't define __str__()." %
klass.__name__)
klass.__unicode__ = klass.__str__
klass.__str__ = lambda self: self.__unicode__().encode('utf-8')
return klass

View File

@ -26,6 +26,10 @@ def usage():
def main(): def main():
try:
django.setup()
except AttributeError:
pass
TestRunner = get_runner(settings) TestRunner = get_runner(settings)
test_runner = TestRunner() test_runner = TestRunner()

View File

@ -8,6 +8,7 @@ from django.conf import settings
from django.test.client import Client as DjangoClient from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler from django.test.client import ClientHandler
from django.test import testcases from django.test import testcases
from django.utils.http import urlencode
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.compat import force_bytes_or_smart_bytes, six
@ -71,6 +72,13 @@ class APIRequestFactory(DjangoRequestFactory):
return ret, content_type return ret, content_type
def get(self, path, data=None, **extra):
r = {
'QUERY_STRING': urlencode(data or {}, doseq=True),
}
r.update(extra)
return self.generic('GET', path, **r)
def post(self, path, data=None, format=None, content_type=None, **extra): def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra) return self.generic('POST', path, data, content_type, **extra)

View File

@ -168,3 +168,10 @@ class NullableOneToOneSource(RESTFrameworkModel):
class BasicModelSerializer(serializers.ModelSerializer): class BasicModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = BasicModel model = BasicModel
# Models to test filters
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()

View File

@ -9,16 +9,11 @@ from rest_framework import generics, serializers, status, filters
from rest_framework.compat import django_filters, patterns, url from rest_framework.compat import django_filters, patterns, url
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import BasicModel
from .models import FilterableItem
factory = APIRequestFactory() factory = APIRequestFactory()
class FilterableItem(models.Model):
text = models.CharField(max_length=100)
decimal = models.DecimalField(max_digits=4, decimal_places=2)
date = models.DateField()
if django_filters: if django_filters:
# Basic filter on a list view. # Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView): class FilterFieldsRootView(generics.ListCreateAPIView):
@ -128,7 +123,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works. # Tests that the decimal filter works.
search_decimal = Decimal('2.25') search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if f['decimal'] == search_decimal]
@ -136,7 +131,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter works. # Tests that the date filter works.
search_date = datetime.date(2012, 9, 22) search_date = datetime.date(2012, 9, 22)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-09-22'
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] == search_date] expected_data = [f for f in self.data if f['date'] == search_date]
@ -151,7 +146,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter works. # Tests that the decimal filter works.
search_decimal = Decimal('2.25') search_decimal = Decimal('2.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if f['decimal'] == search_decimal]
@ -184,7 +179,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set with 'lt' in the filter class works. # Tests that the decimal filter set with 'lt' in the filter class works.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/', {'decimal': '%s' % search_decimal})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] < search_decimal] expected_data = [f for f in self.data if f['decimal'] < search_decimal]
@ -192,7 +187,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the date filter set with 'gt' in the filter class works. # Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' request = factory.get('/', {'date': '%s' % search_date}) # search_date str: '2012-10-02'
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date] expected_data = [f for f in self.data if f['date'] > search_date]
@ -200,7 +195,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that the text filter set with 'icontains' in the filter class works. # Tests that the text filter set with 'icontains' in the filter class works.
search_text = 'ff' search_text = 'ff'
request = factory.get('/?text=%s' % search_text) request = factory.get('/', {'text': '%s' % search_text})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()] expected_data = [f for f in self.data if search_text in f['text'].lower()]
@ -209,7 +204,10 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
# Tests that multiple filters works. # Tests that multiple filters works.
search_decimal = Decimal('5.25') search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) request = factory.get('/', {
'decimal': '%s' % (search_decimal,),
'date': '%s' % (search_date,)
})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and expected_data = [f for f in self.data if f['date'] > search_date and
@ -234,7 +232,7 @@ class IntegrationTestFiltering(CommonFilteringTestCase):
view = FilterFieldsRootView.as_view() view = FilterFieldsRootView.as_view()
search_integer = 10 search_integer = 10
request = factory.get('/?integer=%s' % search_integer) request = factory.get('/', {'integer': '%s' % search_integer})
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -265,14 +263,18 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
# Tests that the decimal filter set that should fail. # Tests that the decimal filter set that should fail.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
high_item = self.objects.filter(decimal__gt=search_decimal)[0] high_item = self.objects.filter(decimal__gt=search_decimal)[0]
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) response = self.client.get(
'{url}'.format(url=self._get_url(high_item)),
{'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
# Tests that the decimal filter set that should succeed. # Tests that the decimal filter set that should succeed.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
low_item = self.objects.filter(decimal__lt=search_decimal)[0] low_item = self.objects.filter(decimal__lt=search_decimal)[0]
low_item_data = self._serialize_object(low_item) low_item_data = self._serialize_object(low_item)
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) response = self.client.get(
'{url}'.format(url=self._get_url(low_item)),
{'decimal': '{param}'.format(param=search_decimal)})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, low_item_data) self.assertEqual(response.data, low_item_data)
@ -281,7 +283,11 @@ class IntegrationTestDetailFiltering(CommonFilteringTestCase):
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
valid_item_data = self._serialize_object(valid_item) valid_item_data = self._serialize_object(valid_item)
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) response = self.client.get(
'{url}'.format(url=self._get_url(valid_item)), {
'decimal': '{decimal}'.format(decimal=search_decimal),
'date': '{date}'.format(date=search_date)
})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, valid_item_data) self.assertEqual(response.data, valid_item_data)
@ -315,7 +321,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', 'text') search_fields = ('title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=b') request = factory.get('/', {'search': 'b'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -332,7 +338,7 @@ class SearchFilterTests(TestCase):
search_fields = ('=title', 'text') search_fields = ('=title', 'text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=zzz') request = factory.get('/', {'search': 'zzz'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -348,7 +354,7 @@ class SearchFilterTests(TestCase):
search_fields = ('title', '^text') search_fields = ('title', '^text')
view = SearchListView.as_view() view = SearchListView.as_view()
request = factory.get('?search=b') request = factory.get('/', {'search': 'b'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -396,7 +402,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=text') request = factory.get('/', {'ordering': 'text'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -415,7 +421,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=-text') request = factory.get('/', {'ordering': '-text'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -434,7 +440,7 @@ class OrderingFilterTests(TestCase):
ordering_fields = ('text',) ordering_fields = ('text',)
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=foobar') request = factory.get('/', {'ordering': 'foobar'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -503,7 +509,7 @@ class OrderingFilterTests(TestCase):
models.Count("relateds")) models.Count("relateds"))
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=relateds__count') request = factory.get('/', {'ordering': 'relateds__count'})
response = view(request) response = view(request)
self.assertEqual( self.assertEqual(
response.data, response.data,
@ -566,7 +572,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=-username') request = factory.get('/', {'ordering': '-username'})
response = view(request) response = view(request)
if serializer_cls == SensitiveDataSerializer3: if serializer_cls == SensitiveDataSerializer3:
@ -596,7 +602,7 @@ class SensitiveOrderingFilterTests(TestCase):
serializer_class = serializer_cls serializer_class = serializer_cls
view = OrderingListView.as_view() view = OrderingListView.as_view()
request = factory.get('?ordering=password') request = factory.get('/', {'ordering': 'password'})
response = view(request) response = view(request)
if serializer_cls == SensitiveDataSerializer3: if serializer_cls == SensitiveDataSerializer3:

View File

@ -4,8 +4,10 @@ from django.contrib.contenttypes.generic import GenericRelation, GenericForeignK
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import python_2_unicode_compatible
@python_2_unicode_compatible
class Tag(models.Model): class Tag(models.Model):
""" """
Tags have a descriptive slug, and are attached to an arbitrary object. Tags have a descriptive slug, and are attached to an arbitrary object.
@ -15,10 +17,11 @@ class Tag(models.Model):
object_id = models.PositiveIntegerField() object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id') tagged_item = GenericForeignKey('content_type', 'object_id')
def __unicode__(self): def __str__(self):
return self.tag return self.tag
@python_2_unicode_compatible
class Bookmark(models.Model): class Bookmark(models.Model):
""" """
A URL bookmark that may have multiple tags attached. A URL bookmark that may have multiple tags attached.
@ -26,10 +29,11 @@ class Bookmark(models.Model):
url = models.URLField() url = models.URLField()
tags = GenericRelation(Tag) tags = GenericRelation(Tag)
def __unicode__(self): def __str__(self):
return 'Bookmark: %s' % self.url return 'Bookmark: %s' % self.url
@python_2_unicode_compatible
class Note(models.Model): class Note(models.Model):
""" """
A textual note that may have multiple tags attached. A textual note that may have multiple tags attached.
@ -37,7 +41,7 @@ class Note(models.Model):
text = models.TextField() text = models.TextField()
tags = GenericRelation(Tag) tags = GenericRelation(Tag)
def __unicode__(self): def __str__(self):
return 'Note: %s' % self.text return 'Note: %s' % self.text

View File

@ -50,7 +50,7 @@ class TemplateHTMLRendererTests(TestCase):
""" """
self.get_template = django.template.loader.get_template self.get_template = django.template.loader.get_template
def get_template(template_name): def get_template(template_name, dirs=None):
if template_name == 'example.html': if template_name == 'example.html':
return Template("example: {{ object }}") return Template("example: {{ object }}")
raise TemplateDoesNotExist(template_name) raise TemplateDoesNotExist(template_name)
@ -108,11 +108,13 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def test_not_found_html_view_with_template(self): def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found') response = self.client.get('/not_found')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404: Not found")) self.assertTrue(response.content in (
six.b("404: Not found"), six.b("404 Not Found")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
def test_permission_denied_html_view_with_template(self): def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied') response = self.client.get('/permission_denied')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403: Permission denied")) self.assertTrue(response.content in (
six.b("403: Permission denied"), six.b("403 Forbidden")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')

View File

@ -9,14 +9,18 @@ from rest_framework import generics, status, pagination, filters, serializers
from rest_framework.compat import django_filters from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.tests.models import BasicModel from rest_framework.tests.models import BasicModel
from .models import FilterableItem
factory = APIRequestFactory() factory = APIRequestFactory()
# Helper function to split arguments out of an url
def split_arguments_from_url(url):
if '?' not in url:
return url
class FilterableItem(models.Model): path, args = url.split('?')
text = models.CharField(max_length=100) args = dict(r.split('=') for r in args.split('&'))
decimal = models.DecimalField(max_digits=4, decimal_places=2) return path, args
date = models.DateField()
class RootView(generics.ListCreateAPIView): class RootView(generics.ListCreateAPIView):
@ -84,7 +88,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -93,7 +97,7 @@ class IntegrationTestPagination(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -146,7 +150,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
EXPECTED_NUM_QUERIES = 2 EXPECTED_NUM_QUERIES = 2
request = factory.get('/?decimal=15.20') request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -155,7 +159,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -164,7 +168,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None) self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous']) request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(EXPECTED_NUM_QUERIES): with self.assertNumQueries(EXPECTED_NUM_QUERIES):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -191,7 +195,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
view = BasicFilterFieldsRootView.as_view() view = BasicFilterFieldsRootView.as_view()
request = factory.get('/?decimal=15.20') request = factory.get('/', {'decimal': '15.20'})
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -200,7 +204,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertNotEqual(response.data['next'], None) self.assertNotEqual(response.data['next'], None)
self.assertEqual(response.data['previous'], None) self.assertEqual(response.data['previous'], None)
request = factory.get(response.data['next']) request = factory.get(*split_arguments_from_url(response.data['next']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -209,7 +213,7 @@ class IntegrationTestPaginationAndFiltering(TestCase):
self.assertEqual(response.data['next'], None) self.assertEqual(response.data['next'], None)
self.assertNotEqual(response.data['previous'], None) self.assertNotEqual(response.data['previous'], None)
request = factory.get(response.data['previous']) request = factory.get(*split_arguments_from_url(response.data['previous']))
with self.assertNumQueries(2): with self.assertNumQueries(2):
response = view(request).render() response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -317,7 +321,7 @@ class TestCustomPaginateByParam(TestCase):
""" """
If paginate_by_param is set, the new kwarg should limit per view requests. If paginate_by_param is set, the new kwarg should limit per view requests.
""" """
request = factory.get('/?page_size=5') request = factory.get('/', {'page_size': 5})
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5]) self.assertEqual(response.data['results'], self.data[:5])
@ -345,7 +349,7 @@ class TestMaxPaginateByParam(TestCase):
""" """
If max_paginate_by is set, it should limit page size for the view. If max_paginate_by is set, it should limit page size for the view.
""" """
request = factory.get('/?page_size=10') request = factory.get('/', data={'page_size': 10})
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.data['count'], 13) self.assertEqual(response.data['count'], 13)
self.assertEqual(response.data['results'], self.data[:5]) self.assertEqual(response.data['results'], self.data[:5])

View File

@ -3,9 +3,7 @@ from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from .models import OneToOneTarget
class OneToOneTarget(models.Model):
name = models.CharField(max_length=100)
class OneToOneSource(models.Model): class OneToOneSource(models.Model):

View File

@ -613,6 +613,10 @@ class CacheRenderTest(TestCase):
method = getattr(self.client, http_method) method = getattr(self.client, http_method)
resp = method(url) resp = method(url)
del resp.client, resp.request del resp.client, resp.request
try:
del resp.wsgi_request
except AttributeError:
pass
return resp return resp
def test_obj_pickling(self): def test_obj_pickling(self):

View File

@ -14,6 +14,26 @@ import datetime
import pickle import pickle
class AMOAFModel(RESTFrameworkModel):
char_field = models.CharField(max_length=1024, blank=True)
comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
email_field = models.EmailField(max_length=1024, blank=True)
file_field = models.FileField(upload_to='test', max_length=1024, blank=True)
image_field = models.ImageField(upload_to='test', max_length=1024, blank=True)
slug_field = models.SlugField(max_length=1024, blank=True)
url_field = models.URLField(max_length=1024, blank=True)
class DVOAFModel(RESTFrameworkModel):
positive_integer_field = models.PositiveIntegerField(blank=True)
positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
email_field = models.EmailField(blank=True)
file_field = models.FileField(upload_to='test', blank=True)
image_field = models.ImageField(upload_to='test', blank=True)
slug_field = models.SlugField(blank=True)
url_field = models.URLField(blank=True)
class SubComment(object): class SubComment(object):
def __init__(self, sub_comment): def __init__(self, sub_comment):
self.sub_comment = sub_comment self.sub_comment = sub_comment
@ -1496,15 +1516,6 @@ class ManyFieldHelpTextTest(TestCase):
class AttributeMappingOnAutogeneratedFieldsTests(TestCase): class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
def setUp(self): def setUp(self):
class AMOAFModel(RESTFrameworkModel):
char_field = models.CharField(max_length=1024, blank=True)
comma_separated_integer_field = models.CommaSeparatedIntegerField(max_length=1024, blank=True)
decimal_field = models.DecimalField(max_digits=64, decimal_places=32, blank=True)
email_field = models.EmailField(max_length=1024, blank=True)
file_field = models.FileField(max_length=1024, blank=True)
image_field = models.ImageField(max_length=1024, blank=True)
slug_field = models.SlugField(max_length=1024, blank=True)
url_field = models.URLField(max_length=1024, blank=True)
class AMOAFSerializer(serializers.ModelSerializer): class AMOAFSerializer(serializers.ModelSerializer):
class Meta: class Meta:
@ -1577,14 +1588,6 @@ class AttributeMappingOnAutogeneratedFieldsTests(TestCase):
class DefaultValuesOnAutogeneratedFieldsTests(TestCase): class DefaultValuesOnAutogeneratedFieldsTests(TestCase):
def setUp(self): def setUp(self):
class DVOAFModel(RESTFrameworkModel):
positive_integer_field = models.PositiveIntegerField(blank=True)
positive_small_integer_field = models.PositiveSmallIntegerField(blank=True)
email_field = models.EmailField(blank=True)
file_field = models.FileField(blank=True)
image_field = models.ImageField(blank=True)
slug_field = models.SlugField(blank=True)
url_field = models.URLField(blank=True)
class DVOAFSerializer(serializers.ModelSerializer): class DVOAFSerializer(serializers.ModelSerializer):
class Meta: class Meta: