Merge branch 'p3k' of github.com:linovia/django-rest-framework into p3k

This commit is contained in:
Xavier Ordoquy 2012-11-22 07:48:41 +01:00
commit be003145ca
23 changed files with 170 additions and 120 deletions

View File

@ -12,6 +12,7 @@ env:
install: install:
- pip install $DJANGO - pip install $DJANGO
- pip install django-filter==0.5.4 --use-mirrors - pip install django-filter==0.5.4 --use-mirrors
- pip install six --use-mirrors
- export PYTHONPATH=. - export PYTHONPATH=.
script: script:

View File

@ -1 +1,2 @@
Django>=1.3 Django>=1.3
six

View File

@ -3,7 +3,11 @@ Provides a set of pluggable authentication policies.
""" """
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
from django.utils.encoding import smart_unicode, DjangoUnicodeDecodeError from django.utils.encoding import DjangoUnicodeDecodeError
try:
from django.utils.encoding import smart_text
except ImportError:
from django.utils.encoding import smart_unicode as smart_text
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import CsrfViewMiddleware
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
@ -36,13 +40,13 @@ class BasicAuthentication(BaseAuthentication):
auth = request.META['HTTP_AUTHORIZATION'].split() auth = request.META['HTTP_AUTHORIZATION'].split()
if len(auth) == 2 and auth[0].lower() == "basic": if len(auth) == 2 and auth[0].lower() == "basic":
try: try:
auth_parts = base64.b64decode(auth[1]).partition(':') auth_parts = base64.b64decode(auth[1].encode('utf8')).decode('utf8').partition(':')
except TypeError: except TypeError:
return None return None
try: try:
userid = smart_unicode(auth_parts[0]) userid = smart_text(auth_parts[0])
password = smart_unicode(auth_parts[2]) password = smart_text(auth_parts[2])
except DjangoUnicodeDecodeError: except DjangoUnicodeDecodeError:
return None return None

View File

@ -3,6 +3,9 @@ The `compat` module provides support for backwards compatibility with older
versions of django/python, and compatibility wrappers around optional packages. versions of django/python, and compatibility wrappers around optional packages.
""" """
# flake8: noqa # flake8: noqa
from __future__ import unicode_literals
import six
import django import django
# django-filter is optional # django-filter is optional
@ -14,9 +17,9 @@ except:
# cStringIO only if it's available, otherwise StringIO # cStringIO only if it's available, otherwise StringIO
try: try:
import cStringIO as StringIO import cStringIO.StringIO as StringIO
except ImportError: except ImportError:
import StringIO from six import StringIO
def get_concrete_model(model_cls): def get_concrete_model(model_cls):
@ -38,7 +41,7 @@ else:
try: try:
from django.contrib.auth.models import User from django.contrib.auth.models import User
except ImportError: except ImportError:
raise ImportError(u"User model is not to be found.") raise ImportError("User model is not to be found.")
# First implementation of Django class-based views did not include head method # First implementation of Django class-based views did not include head method
@ -59,11 +62,11 @@ else:
# sanitize keyword arguments # sanitize keyword arguments
for key in initkwargs: for key in initkwargs:
if key in cls.http_method_names: if key in cls.http_method_names:
raise TypeError(u"You tried to pass in the %s method name as a " raise TypeError("You tried to pass in the %s method name as a "
u"keyword argument to %s(). Don't do that." "keyword argument to %s(). Don't do that."
% (key, cls.__name__)) % (key, cls.__name__))
if not hasattr(cls, key): if not hasattr(cls, key):
raise TypeError(u"%s() received an invalid keyword %r" % ( raise TypeError("%s() received an invalid keyword %r" % (
cls.__name__, key)) cls.__name__, key))
def view(request, *args, **kwargs): def view(request, *args, **kwargs):
@ -130,7 +133,8 @@ else:
randrange = random.SystemRandom().randrange randrange = random.SystemRandom().randrange
else: else:
randrange = random.randrange randrange = random.randrange
_MAX_CSRF_KEY = 18446744073709551616L # 2 << 63
_MAX_CSRF_KEY = 18446744073709551616 # 2 << 63
REASON_NO_REFERER = "Referer checking failed - no Referer." REASON_NO_REFERER = "Referer checking failed - no Referer."
REASON_BAD_REFERER = "Referer checking failed - %s does not match %s." REASON_BAD_REFERER = "Referer checking failed - %s does not match %s."

View File

@ -1,3 +1,7 @@
from __future__ import unicode_literals
import six
import copy import copy
import datetime import datetime
import inspect import inspect
@ -12,12 +16,19 @@ from django.core.urlresolvers import resolve, get_script_prefix
from django.conf import settings from django.conf import settings
from django.forms import widgets from django.forms import widgets
from django.forms.models import ModelChoiceIterator from django.forms.models import ModelChoiceIterator
from django.utils.encoding import is_protected_type, smart_unicode from django.utils.encoding import is_protected_type
try:
from django.utils.encoding import smart_text
except ImportError:
from django.utils.encoding import smart_unicode as smart_text
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.reverse import reverse from rest_framework.reverse import reverse
from rest_framework.compat import parse_date, parse_datetime from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone from rest_framework.compat import timezone
from urlparse import urlparse try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
def is_simple_callable(obj): def is_simple_callable(obj):
@ -92,11 +103,11 @@ class Field(object):
if is_protected_type(value): if is_protected_type(value):
return value return value
elif hasattr(value, '__iter__') and not isinstance(value, (dict, basestring)): elif hasattr(value, '__iter__') and not isinstance(value, (dict, six.string_types)):
return [self.to_native(item) for item in value] return [self.to_native(item) for item in value]
elif isinstance(value, dict): elif isinstance(value, dict):
return dict(map(self.to_native, (k, v)) for k, v in value.items()) return dict(map(self.to_native, (k, v)) for k, v in value.items())
return smart_unicode(value) return smart_text(value)
def attributes(self): def attributes(self):
""" """
@ -297,8 +308,8 @@ class RelatedField(WritableField):
""" """
Return a readable representation for use with eg. select widgets. Return a readable representation for use with eg. select widgets.
""" """
desc = smart_unicode(obj) desc = smart_text(obj)
ident = smart_unicode(self.to_native(obj)) ident = smart_text(self.to_native(obj))
if desc == ident: if desc == ident:
return desc return desc
return "%s - %s" % (desc, ident) return "%s - %s" % (desc, ident)
@ -401,8 +412,8 @@ class PrimaryKeyRelatedField(RelatedField):
""" """
Return a readable representation for use with eg. select widgets. Return a readable representation for use with eg. select widgets.
""" """
desc = smart_unicode(obj) desc = smart_text(obj)
ident = smart_unicode(self.to_native(obj.pk)) ident = smart_text(self.to_native(obj.pk))
if desc == ident: if desc == ident:
return desc return desc
return "%s - %s" % (desc, ident) return "%s - %s" % (desc, ident)
@ -418,7 +429,7 @@ class PrimaryKeyRelatedField(RelatedField):
try: try:
return self.queryset.get(pk=data) return self.queryset.get(pk=data)
except ObjectDoesNotExist: except ObjectDoesNotExist:
msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) msg = "Invalid pk '%s' - object does not exist." % smart_text(data)
raise ValidationError(msg) raise ValidationError(msg)
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
@ -446,8 +457,8 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
""" """
Return a readable representation for use with eg. select widgets. Return a readable representation for use with eg. select widgets.
""" """
desc = smart_unicode(obj) desc = smart_text(obj)
ident = smart_unicode(self.to_native(obj.pk)) ident = smart_text(self.to_native(obj.pk))
if desc == ident: if desc == ident:
return desc return desc
return "%s - %s" % (desc, ident) return "%s - %s" % (desc, ident)
@ -473,7 +484,7 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
try: try:
return self.queryset.get(pk=data) return self.queryset.get(pk=data)
except ObjectDoesNotExist: except ObjectDoesNotExist:
msg = "Invalid pk '%s' - object does not exist." % smart_unicode(data) msg = "Invalid pk '%s' - object does not exist." % smart_text(data)
raise ValidationError(msg) raise ValidationError(msg)
### Slug relationships ### Slug relationships
@ -674,7 +685,7 @@ class BooleanField(WritableField):
type_name = 'BooleanField' type_name = 'BooleanField'
widget = widgets.CheckboxInput widget = widgets.CheckboxInput
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."), 'invalid': _("'%s' value must be either True or False."),
} }
empty = False empty = False
@ -713,9 +724,9 @@ class CharField(WritableField):
super(CharField, self).validate(value) super(CharField, self).validate(value)
def from_native(self, value): def from_native(self, value):
if isinstance(value, basestring) or value is None: if isinstance(value, six.string_types) or value is None:
return value return value
return smart_unicode(value) return smart_text(value)
class URLField(CharField): class URLField(CharField):
@ -773,10 +784,10 @@ class ChoiceField(WritableField):
if isinstance(v, (list, tuple)): if isinstance(v, (list, tuple)):
# This is an optgroup, so look inside the group for options # This is an optgroup, so look inside the group for options
for k2, v2 in v: for k2, v2 in v:
if value == smart_unicode(k2): if value == smart_text(k2):
return True return True
else: else:
if value == smart_unicode(k): if value == smart_text(k):
return True return True
return False return False
@ -814,7 +825,7 @@ class RegexField(CharField):
return self._regex return self._regex
def _set_regex(self, regex): def _set_regex(self, regex):
if isinstance(regex, basestring): if isinstance(regex, six.string_types):
regex = re.compile(regex) regex = re.compile(regex)
self._regex = regex self._regex = regex
if hasattr(self, '_regex_validator') and self._regex_validator in self.validators: if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
@ -835,10 +846,10 @@ class DateField(WritableField):
type_name = 'DateField' type_name = 'DateField'
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be " 'invalid': _("'%s' value has an invalid date format. It must be "
u"in YYYY-MM-DD format."), "in YYYY-MM-DD format."),
'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) " 'invalid_date': _("'%s' value has the correct format (YYYY-MM-DD) "
u"but it is an invalid date."), "but it is an invalid date."),
} }
empty = None empty = None
@ -872,13 +883,13 @@ class DateTimeField(WritableField):
type_name = 'DateTimeField' type_name = 'DateTimeField'
default_error_messages = { default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in " 'invalid': _("'%s' value has an invalid format. It must be in "
u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."), "YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
'invalid_date': _(u"'%s' value has the correct format " 'invalid_date': _("'%s' value has the correct format "
u"(YYYY-MM-DD) but it is an invalid date."), "(YYYY-MM-DD) but it is an invalid date."),
'invalid_datetime': _(u"'%s' value has the correct format " 'invalid_datetime': _("'%s' value has the correct format "
u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) " "(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
u"but it is an invalid date/time."), "but it is an invalid date/time."),
} }
empty = None empty = None
@ -895,8 +906,8 @@ class DateTimeField(WritableField):
# local time. This won't work during DST change, but we can't # local time. This won't work during DST change, but we can't
# do much about it, so we let the exceptions percolate up the # do much about it, so we let the exceptions percolate up the
# call stack. # call stack.
warnings.warn(u"DateTimeField received a naive datetime (%s)" warnings.warn("DateTimeField received a naive datetime (%s)"
u" while time zone support is active." % value, " while time zone support is active." % value,
RuntimeWarning) RuntimeWarning)
default_timezone = timezone.get_default_timezone() default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone) value = timezone.make_aware(value, default_timezone)

View File

@ -4,6 +4,8 @@ Basic building blocks for generic class based views.
We don't bind behaviour to http method handlers yet, We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways. which allows mixin classes to be composed in interesting ways.
""" """
from __future__ import unicode_literals
from django.http import Http404 from django.http import Http404
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
@ -38,7 +40,7 @@ class ListModelMixin(object):
List a queryset. List a queryset.
Should be mixed in with `MultipleObjectAPIView`. Should be mixed in with `MultipleObjectAPIView`.
""" """
empty_error = u"Empty list and '%(class_name)s.allow_empty' is False." empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
queryset = self.get_queryset() queryset = self.get_queryset()

View File

@ -56,7 +56,7 @@ class JSONParser(BaseParser):
""" """
try: try:
return json.load(stream) return json.load(stream)
except ValueError, exc: except ValueError as exc:
raise ParseError('JSON parse error - %s' % unicode(exc)) raise ParseError('JSON parse error - %s' % unicode(exc))
@ -76,7 +76,7 @@ class YAMLParser(BaseParser):
""" """
try: try:
return yaml.safe_load(stream) return yaml.safe_load(stream)
except (ValueError, yaml.parser.ParserError), exc: except (ValueError, yaml.parser.ParserError) as exc:
raise ParseError('YAML parse error - %s' % unicode(exc)) raise ParseError('YAML parse error - %s' % unicode(exc))
@ -121,7 +121,7 @@ class MultiPartParser(BaseParser):
parser = DjangoMultiPartParser(meta, stream, upload_handlers) parser = DjangoMultiPartParser(meta, stream, upload_handlers)
data, files = parser.parse() data, files = parser.parse()
return DataAndFiles(data, files) return DataAndFiles(data, files)
except MultiPartParserError, exc: except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % unicode(exc)) raise ParseError('Multipart form parse error - %s' % unicode(exc))
@ -135,7 +135,7 @@ class XMLParser(BaseParser):
def parse(self, stream, media_type=None, parser_context=None): def parse(self, stream, media_type=None, parser_context=None):
try: try:
tree = ET.parse(stream) tree = ET.parse(stream)
except (ExpatError, ETParseError, ValueError), exc: except (ExpatError, ETParseError, ValueError) as exc:
raise ParseError('XML parse error - %s' % unicode(exc)) raise ParseError('XML parse error - %s' % unicode(exc))
data = self._xml_convert(tree.getroot()) data = self._xml_convert(tree.getroot())

View File

@ -6,6 +6,8 @@ on the response, such as JSON encoded data or HTML output.
REST framework also provides an HTML renderer the renders the browsable API. REST framework also provides an HTML renderer the renders the browsable API.
""" """
from __future__ import unicode_literals
import copy import copy
import string import string
from django import forms from django import forms
@ -60,7 +62,7 @@ class JSONRenderer(BaseRenderer):
if accepted_media_type: if accepted_media_type:
# If the media type looks like 'application/json; indent=4', # If the media type looks like 'application/json; indent=4',
# then pretty print the result. # then pretty print the result.
base_media_type, params = parse_header(accepted_media_type) base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
indent = params.get('indent', indent) indent = params.get('indent', indent)
try: try:
indent = max(min(int(indent), 8), 0) indent = max(min(int(indent), 8), 0)
@ -100,7 +102,7 @@ class JSONPRenderer(JSONRenderer):
callback = self.get_callback(renderer_context) callback = self.get_callback(renderer_context)
json = super(JSONPRenderer, self).render(data, accepted_media_type, json = super(JSONPRenderer, self).render(data, accepted_media_type,
renderer_context) renderer_context)
return u"%s(%s);" % (callback, json) return "%s(%s);" % (callback, json)
class XMLRenderer(BaseRenderer): class XMLRenderer(BaseRenderer):

View File

@ -9,7 +9,7 @@ The wrapped request then offers a richer API, in particular :
- full support of PUT method, including support for file uploads - full support of PUT method, including support for file uploads
- form overloading of HTTP method, content type and content - form overloading of HTTP method, content type and content
""" """
from StringIO import StringIO from rest_framework.compat import StringIO
from django.http.multipartparser import parse_header from django.http.multipartparser import parse_header
from rest_framework import exceptions from rest_framework import exceptions
@ -20,7 +20,7 @@ def is_form_media_type(media_type):
""" """
Return True if the media type is a valid form media type. Return True if the media type is a valid form media type.
""" """
base_media_type, params = parse_header(media_type) base_media_type, params = parse_header(media_type.encode('utf8'))
return (base_media_type == 'application/x-www-form-urlencoded' or return (base_media_type == 'application/x-www-form-urlencoded' or
base_media_type == 'multipart/form-data') base_media_type == 'multipart/form-data')

View File

@ -19,6 +19,8 @@ back to the defaults.
""" """
from django.conf import settings from django.conf import settings
from django.utils import importlib from django.utils import importlib
from six import string_types
USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None) USER_SETTINGS = getattr(settings, 'REST_FRAMEWORK', None)
@ -98,7 +100,7 @@ def perform_import(val, setting_name):
If the given setting is a string import notation, If the given setting is a string import notation,
then perform the necessary import or imports. then perform the necessary import or imports.
""" """
if isinstance(val, basestring): if isinstance(val, string_types):
return import_from_string(val, setting_name) return import_from_string(val, setting_name)
elif isinstance(val, (list, tuple)): elif isinstance(val, (list, tuple)):
return [import_from_string(item, setting_name) for item in val] return [import_from_string(item, setting_name) for item in val]

View File

@ -1,10 +1,18 @@
from __future__ import unicode_literals
from django import template from django import template
from django.core.urlresolvers import reverse from django.core.urlresolvers import reverse
from django.http import QueryDict from django.http import QueryDict
from django.utils.encoding import force_unicode try:
from django.utils.encoding import force_text
except ImportError:
from django.utils.encoding import force_unicode as force_text
from django.utils.html import escape from django.utils.html import escape
from django.utils.safestring import SafeData, mark_safe from django.utils.safestring import SafeData, mark_safe
from urlparse import urlsplit, urlunsplit try:
from urllib.parse import urlsplit, urlunsplit
except ImportError:
from urlparse import urlsplit, urlunsplit
import re import re
import string import string
@ -130,7 +138,7 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
""" """
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData) safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_unicode(text)) words = word_split_re.split(force_text(text))
nofollow_attr = nofollow and ' rel="nofollow"' or '' nofollow_attr = nofollow and ' rel="nofollow"' or ''
for i, word in enumerate(words): for i, word in enumerate(words):
match = None match = None
@ -166,4 +174,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words[i] = mark_safe(word) words[i] = mark_safe(word)
elif autoescape: elif autoescape:
words[i] = escape(word) words[i] = escape(word)
return mark_safe(u''.join(words)) return mark_safe(''.join(words))

View File

@ -44,13 +44,13 @@ class BasicAuthTests(TestCase):
def test_post_form_passing_basic_auth(self): def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() auth = 'Basic ' + base64.encodestring(('%s:%s' % (self.username, self.password)).encode('utf8')).strip().decode('utf8')
response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_post_json_passing_basic_auth(self): def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() auth = 'Basic ' + base64.encodestring(('%s:%s' % (self.username, self.password)).encode('utf8')).strip().decode('utf8')
response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)

View File

@ -1,4 +1,5 @@
import StringIO from rest_framework.compat import StringIO
import datetime import datetime
from django.test import TestCase from django.test import TestCase
@ -28,7 +29,7 @@ class FileSerializerTests(TestCase):
def test_create(self): def test_create(self):
now = datetime.datetime.now() now = datetime.datetime.now()
file = StringIO.StringIO('stuff') file = StringIO('stuff')
file.name = 'stuff.txt' file.name = 'stuff.txt'
file.size = file.len file.size = file.len
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) serializer = UploadedFileSerializer(data={'created': now}, files={'file': file})

View File

@ -1,3 +1,5 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import * from rest_framework.tests.models import *
@ -27,7 +29,7 @@ class TestGenericRelations(TestCase):
serializer = BookmarkSerializer(self.bookmark) serializer = BookmarkSerializer(self.bookmark)
expected = { expected = {
'tags': [u'django', u'python'], 'tags': ['django', 'python'],
'url': u'https://www.djangoproject.com/' 'url': 'https://www.djangoproject.com/'
} }
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -1,3 +1,5 @@
from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.utils import simplejson as json from django.utils import simplejson as json
@ -71,7 +73,7 @@ class TestRootView(TestCase):
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) self.assertEquals(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4) created = self.objects.get(id=4)
self.assertEquals(created.text, 'foobar') self.assertEquals(created.text, 'foobar')
@ -126,7 +128,7 @@ class TestRootView(TestCase):
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) self.assertEquals(response.data, {'id': 4, 'text': 'foobar'})
created = self.objects.get(id=4) created = self.objects.get(id=4)
self.assertEquals(created.text, 'foobar') self.assertEquals(created.text, 'foobar')

View File

@ -131,7 +131,7 @@
# self.assertEqual(data['key1'], 'val1') # self.assertEqual(data['key1'], 'val1')
# self.assertEqual(files['file1'].read(), 'blablabla') # self.assertEqual(files['file1'].read(), 'blablabla')
from StringIO import StringIO from rest_framework.compat import StringIO
from django import forms from django import forms
from django.test import TestCase from django.test import TestCase
from rest_framework.parsers import FormParser from rest_framework.parsers import FormParser

View File

@ -1,3 +1,5 @@
from __future__ import unicode_literals
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
@ -65,9 +67,9 @@ class PrimaryKeyManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'targets': [1]}, {'id': 1, 'name': 'source-1', 'targets': [1]},
{'id': 2, 'name': u'source-2', 'targets': [1, 2]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -75,14 +77,14 @@ class PrimaryKeyManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': [2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': u'target-3', 'sources': [3]} {'id': 3, 'name': 'target-3', 'sources': [3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_many_to_many_update(self): def test_many_to_many_update(self):
data = {'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]} data = {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]}
instance = ManyToManySource.objects.get(pk=1) instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data) serializer = ManyToManySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -93,14 +95,14 @@ class PrimaryKeyManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'targets': [1, 2, 3]}, {'id': 1, 'name': 'source-1', 'targets': [1, 2, 3]},
{'id': 2, 'name': u'source-2', 'targets': [1, 2]}, {'id': 2, 'name': 'source-2', 'targets': [1, 2]},
{'id': 3, 'name': u'source-3', 'targets': [1, 2, 3]} {'id': 3, 'name': 'source-3', 'targets': [1, 2, 3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'id': 1, 'name': u'target-1', 'sources': [1]} data = {'id': 1, 'name': 'target-1', 'sources': [1]}
instance = ManyToManyTarget.objects.get(pk=1) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data) serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -111,28 +113,28 @@ class PrimaryKeyManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1]}, {'id': 1, 'name': 'target-1', 'sources': [1]},
{'id': 2, 'name': u'target-2', 'sources': [2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': u'target-3', 'sources': [3]} {'id': 3, 'name': 'target-3', 'sources': [3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_create(self): def test_reverse_many_to_many_create(self):
data = {'id': 4, 'name': u'target-4', 'sources': [1, 3]} data = {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
serializer = ManyToManyTargetSerializer(data=data) serializer = ManyToManyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEquals(serializer.data, data) self.assertEquals(serializer.data, data)
self.assertEqual(obj.name, u'target-4') self.assertEqual(obj.name, 'target-4')
# Ensure target 4 is added, and everything else is as expected # Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': [2, 3]}, {'id': 2, 'name': 'target-2', 'sources': [2, 3]},
{'id': 3, 'name': u'target-3', 'sources': [3]}, {'id': 3, 'name': 'target-3', 'sources': [3]},
{'id': 4, 'name': u'target-4', 'sources': [1, 3]} {'id': 4, 'name': 'target-4', 'sources': [1, 3]}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -151,9 +153,9 @@ class PrimaryKeyForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 1}, {'id': 1, 'name': 'source-1', 'target': 1},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1} {'id': 3, 'name': 'source-3', 'target': 1}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -161,13 +163,13 @@ class PrimaryKeyForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': [1, 2, 3]}, {'id': 1, 'name': 'target-1', 'sources': [1, 2, 3]},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': 'target-2', 'sources': []},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'id': 1, 'name': u'source-1', 'target': 2} data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -178,9 +180,9 @@ class PrimaryKeyForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': 2}, {'id': 1, 'name': 'source-1', 'target': 2},
{'id': 2, 'name': u'source-2', 'target': 1}, {'id': 2, 'name': 'source-2', 'target': 1},
{'id': 3, 'name': u'source-3', 'target': 1} {'id': 3, 'name': 'source-3', 'target': 1}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -189,7 +191,7 @@ class PrimaryKeyForeignKeyTests(TestCase):
# and cannot be arbitrarily set. # and cannot be arbitrarily set.
# def test_reverse_foreign_key_update(self): # def test_reverse_foreign_key_update(self):
# data = {'id': 1, 'name': u'target-1', 'sources': [1]} # data = {'id': 1, 'name': 'target-1', 'sources': [1]}
# instance = ForeignKeyTarget.objects.get(pk=1) # instance = ForeignKeyTarget.objects.get(pk=1)
# serializer = ForeignKeyTargetSerializer(instance, data=data) # serializer = ForeignKeyTargetSerializer(instance, data=data)
# self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
@ -200,7 +202,7 @@ class PrimaryKeyForeignKeyTests(TestCase):
# queryset = ForeignKeyTarget.objects.all() # queryset = ForeignKeyTarget.objects.all()
# serializer = ForeignKeyTargetSerializer(queryset) # serializer = ForeignKeyTargetSerializer(queryset)
# expected = [ # expected = [
# {'id': 1, 'name': u'target-1', 'sources': [1]}, # {'id': 1, 'name': 'target-1', 'sources': [1]},
# {'id': 2, 'name': u'target-2', 'sources': []}, # {'id': 2, 'name': 'target-2', 'sources': []},
# ] # ]
# self.assertEquals(serializer.data, expected) # self.assertEquals(serializer.data, expected)

View File

@ -15,7 +15,7 @@ from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \
from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.parsers import YAMLParser, XMLParser
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from StringIO import StringIO from rest_framework.compat import StringIO
import datetime import datetime
from decimal import Decimal from decimal import Decimal

View File

@ -1,3 +1,5 @@
from __future__ import unicode_literals
import datetime import datetime
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
@ -163,12 +165,12 @@ class ValidationTests(TestCase):
def test_create(self): def test_create(self):
serializer = CommentSerializer(data=self.data) serializer = CommentSerializer(data=self.data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) self.assertEquals(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
def test_update(self): def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data) serializer = CommentSerializer(self.comment, data=self.data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) self.assertEquals(serializer.errors, {'content': ['Ensure this value has at most 1000 characters (it has 1001).']})
def test_update_missing_field(self): def test_update_missing_field(self):
data = { data = {
@ -177,7 +179,7 @@ class ValidationTests(TestCase):
} }
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) self.assertEquals(serializer.errors, {'email': ['This field is required.']})
def test_missing_bool_with_default(self): def test_missing_bool_with_default(self):
"""Make sure that a boolean value with a 'False' value is not """Make sure that a boolean value with a 'False' value is not
@ -213,7 +215,7 @@ class ValidationTests(TestCase):
serializer = CommentSerializerWithFieldValidator(data=data) serializer = CommentSerializerWithFieldValidator(data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) self.assertEquals(serializer.errors, {'content': ['Test not in value']})
def test_cross_field_validation(self): def test_cross_field_validation(self):
@ -237,7 +239,7 @@ class ValidationTests(TestCase):
serializer = CommentSerializerWithCrossFieldValidator(data=data) serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) self.assertEquals(serializer.errors, {'non_field_errors': ['Email address not in content']})
def test_null_is_true_fields(self): def test_null_is_true_fields(self):
""" """
@ -253,7 +255,7 @@ class ValidationTests(TestCase):
} }
serializer = ActionItemSerializer(data=data) serializer = ActionItemSerializer(data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'title': [u'Ensure this value has at most 200 characters (it has 201).']}) self.assertEquals(serializer.errors, {'title': ['Ensure this value has at most 200 characters (it has 201).']})
def test_default_modelfield_max_length_exceeded(self): def test_default_modelfield_max_length_exceeded(self):
data = { data = {
@ -262,22 +264,22 @@ class ValidationTests(TestCase):
} }
serializer = ActionItemSerializer(data=data) serializer = ActionItemSerializer(data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'info': [u'Ensure this value has at most 12 characters (it has 13).']}) self.assertEquals(serializer.errors, {'info': ['Ensure this value has at most 12 characters (it has 13).']})
class RegexValidationTest(TestCase): class RegexValidationTest(TestCase):
def test_create_failed(self): def test_create_failed(self):
serializer = BookSerializer(data={'isbn': '1234567890'}) serializer = BookSerializer(data={'isbn': '1234567890'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': '12345678901234'}) serializer = BookSerializer(data={'isbn': '12345678901234'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': ['isbn has to be exact 13 numbers']})
def test_create_success(self): def test_create_success(self):
serializer = BookSerializer(data={'isbn': '1234567890123'}) serializer = BookSerializer(data={'isbn': '1234567890123'})
@ -574,8 +576,8 @@ class SerializerMethodFieldTests(TestCase):
serializer = self.serializer_class(source_data) serializer = self.serializer_class(source_data)
expected = { expected = {
'beep': u'hello!', 'beep': 'hello!',
'boop': [u'a', u'b', u'c'], 'boop': ['a', 'b', 'c'],
'boop_count': 3, 'boop_count': 3,
} }

View File

@ -1,3 +1,5 @@
from __future__ import unicode_literals
import copy import copy
from django.test import TestCase from django.test import TestCase
from django.test.client import RequestFactory from django.test.client import RequestFactory
@ -47,7 +49,7 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': u'JSON parse error - No JSON object could be decoded' 'detail': 'JSON parse error - No JSON object could be decoded'
} }
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEquals(sanitise_json_error(response.data), expected) self.assertEquals(sanitise_json_error(response.data), expected)
@ -62,7 +64,7 @@ class ClassBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data) request = factory.post('/', form_data)
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': u'JSON parse error - No JSON object could be decoded' 'detail': 'JSON parse error - No JSON object could be decoded'
} }
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEquals(sanitise_json_error(response.data), expected) self.assertEquals(sanitise_json_error(response.data), expected)
@ -76,7 +78,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': u'JSON parse error - No JSON object could be decoded' 'detail': 'JSON parse error - No JSON object could be decoded'
} }
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEquals(sanitise_json_error(response.data), expected) self.assertEquals(sanitise_json_error(response.data), expected)
@ -91,7 +93,7 @@ class FunctionBasedViewIntegrationTests(TestCase):
request = factory.post('/', form_data) request = factory.post('/', form_data)
response = self.view(request) response = self.view(request)
expected = { expected = {
'detail': u'JSON parse error - No JSON object could be decoded' 'detail': 'JSON parse error - No JSON object could be decoded'
} }
self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEquals(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEquals(sanitise_json_error(response.data), expected) self.assertEquals(sanitise_json_error(response.data), expected)

View File

@ -1,4 +1,8 @@
from django.utils.encoding import smart_unicode
try:
from django.utils.encoding import smart_text
except ImportError:
from django.utils.encoding import smart_unicode as smart_text
from django.utils.xmlutils import SimplerXMLGenerator from django.utils.xmlutils import SimplerXMLGenerator
from rest_framework.compat import StringIO from rest_framework.compat import StringIO
import re import re
@ -80,10 +84,10 @@ class XMLRenderer():
pass pass
else: else:
xml.characters(smart_unicode(data)) xml.characters(smart_text(data))
def dict2xml(self, data): def dict2xml(self, data):
stream = StringIO.StringIO() stream = StringIO()
xml = SimplerXMLGenerator(stream, "utf-8") xml = SimplerXMLGenerator(stream, "utf-8")
xml.startDocument() xml.startDocument()

View File

@ -47,7 +47,7 @@ class _MediaType(object):
if media_type_str is None: if media_type_str is None:
media_type_str = '' media_type_str = ''
self.orig = media_type_str self.orig = media_type_str
self.full_type, self.params = parse_header(media_type_str) self.full_type, self.params = parse_header(media_type_str.encode('utf8'))
self.main_type, sep, self.sub_type = self.full_type.partition('/') self.main_type, sep, self.sub_type = self.full_type.partition('/')
def match(self, other): def match(self, other):

View File

@ -63,7 +63,7 @@ setup(
packages=get_packages('rest_framework'), packages=get_packages('rest_framework'),
package_data=get_package_data('rest_framework'), package_data=get_package_data('rest_framework'),
test_suite='rest_framework.runtests.runtests.main', test_suite='rest_framework.runtests.runtests.main',
install_requires=[], install_requires=['six'],
classifiers=[ classifiers=[
'Development Status :: 4 - Beta', 'Development Status :: 4 - Beta',
'Environment :: Web Environment', 'Environment :: Web Environment',