Merge pull request #1780 from tomchristie/pytest-tweaks

Added `runtests.py` and `flake8` code linting.
This commit is contained in:
Tom Christie 2014-08-19 16:15:48 +01:00
commit 390061bed0
62 changed files with 594 additions and 354 deletions

View File

@ -19,6 +19,7 @@ install:
- pip install Pillow==2.3.0 - pip install Pillow==2.3.0
- pip install django-guardian==1.2.3 - pip install django-guardian==1.2.3
- pip install pytest-django==2.6.1 - pip install pytest-django==2.6.1
- pip install flake8==2.2.2
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install oauth2==1.5.211; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.4; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth-plus==2.2.4; fi"
- "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi" - "if [[ ${TRAVIS_PYTHON_VERSION::1} != '3' ]]; then pip install django-oauth2-provider==0.2.4; fi"
@ -28,7 +29,7 @@ install:
- export PYTHONPATH=. - export PYTHONPATH=.
script: script:
- py.test - ./runtests.py
matrix: matrix:
exclude: exclude:

View File

@ -1,2 +0,0 @@
[pytest]
addopts = --tb=short

View File

@ -1,3 +1,10 @@
# Test requirements
pytest-django==2.6
pytest==2.5.2
pytest-cov==1.6
flake8==2.2.2
# Optional packages
markdown>=2.1.0 markdown>=2.1.0
PyYAML>=3.10 PyYAML>=3.10
defusedxml>=0.3 defusedxml>=0.3

View File

@ -1,3 +1 @@
-e .
Django>=1.3 Django>=1.3
pytest-django==2.6

View File

@ -1,5 +1,5 @@
""" """
______ _____ _____ _____ __ _ ______ _____ _____ _____ __
| ___ \ ___/ ___|_ _| / _| | | | ___ \ ___/ ___|_ _| / _| | |
| |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| |__ | |_/ / |__ \ `--. | | | |_ _ __ __ _ _ __ ___ _____ _____ _ __| |__
| /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ / | /| __| `--. \ | | | _| '__/ _` | '_ ` _ \ / _ \ \ /\ / / _ \| '__| |/ /

View File

@ -21,7 +21,7 @@ def get_authorization_header(request):
Hide some test client ickyness where the header can be unicode. Hide some test client ickyness where the header can be unicode.
""" """
auth = request.META.get('HTTP_AUTHORIZATION', b'') auth = request.META.get('HTTP_AUTHORIZATION', b'')
if type(auth) == type(''): if isinstance(auth, type('')):
# Work around django test client oddness # Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING) auth = auth.encode(HTTP_HEADER_ENCODING)
return auth return auth

View File

@ -1,6 +1,5 @@
import binascii import binascii
import os import os
from hashlib import sha1
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models

View File

@ -1,11 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import datetime
from south.db import db from south.db import db
from south.v2 import SchemaMigration from south.v2 import SchemaMigration
from django.db import models
from rest_framework.settings import api_settings
try: try:
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
@ -26,12 +21,10 @@ class Migration(SchemaMigration):
)) ))
db.send_create_signal('authtoken', ['Token']) db.send_create_signal('authtoken', ['Token'])
def backwards(self, orm): def backwards(self, orm):
# Deleting model 'Token' # Deleting model 'Token'
db.delete_table('authtoken_token') db.delete_table('authtoken_token')
models = { models = {
'auth.group': { 'auth.group': {
'Meta': {'object_name': 'Group'}, 'Meta': {'object_name': 'Group'},

View File

@ -131,6 +131,7 @@ def list_route(methods=['get'], **kwargs):
return func return func
return decorator return decorator
# These are now pending deprecation, in favor of `detail_route` and `list_route`. # These are now pending deprecation, in favor of `detail_route` and `list_route`.
def link(**kwargs): def link(**kwargs):
@ -139,11 +140,13 @@ def link(**kwargs):
""" """
msg = 'link is pending deprecation. Use detail_route instead.' msg = 'link is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func): def decorator(func):
func.bind_to_methods = ['get'] func.bind_to_methods = ['get']
func.detail = True func.detail = True
func.kwargs = kwargs func.kwargs = kwargs
return func return func
return decorator return decorator
@ -153,9 +156,11 @@ def action(methods=['post'], **kwargs):
""" """
msg = 'action is pending deprecation. Use detail_route instead.' msg = 'action is pending deprecation. Use detail_route instead.'
warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
def decorator(func): def decorator(func):
func.bind_to_methods = methods func.bind_to_methods = methods
func.detail = True func.detail = True
func.kwargs = kwargs func.kwargs = kwargs
return func return func
return decorator return decorator

View File

@ -23,6 +23,7 @@ class APIException(Exception):
def __str__(self): def __str__(self):
return self.detail return self.detail
class ParseError(APIException): class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Malformed request.' default_detail = 'Malformed request.'

View File

@ -63,8 +63,10 @@ def get_component(obj, attr_name):
def readable_datetime_formats(formats): def readable_datetime_formats(formats):
format = ', '.join(formats).replace(ISO_8601, format = ', '.join(formats).replace(
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]') ISO_8601,
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
)
return humanize_strptime(format) return humanize_strptime(format)
@ -425,7 +427,7 @@ class ModelField(WritableField):
} }
##### Typed Fields ##### # Typed Fields
class BooleanField(WritableField): class BooleanField(WritableField):
type_name = 'BooleanField' type_name = 'BooleanField'
@ -484,7 +486,7 @@ class URLField(CharField):
type_label = 'url' type_label = 'url'
def __init__(self, **kwargs): def __init__(self, **kwargs):
if not 'validators' in kwargs: if 'validators' not in kwargs:
kwargs['validators'] = [validators.URLValidator()] kwargs['validators'] = [validators.URLValidator()]
super(URLField, self).__init__(**kwargs) super(URLField, self).__init__(**kwargs)

View File

@ -25,6 +25,7 @@ def strict_positive_int(integer_string, cutoff=None):
ret = min(ret, cutoff) ret = min(ret, cutoff)
return ret return ret
def get_object_or_404(queryset, *filter_args, **filter_kwargs): def get_object_or_404(queryset, *filter_args, **filter_kwargs):
""" """
Same as Django's standard shortcut, but make sure to raise 404 Same as Django's standard shortcut, but make sure to raise 404
@ -162,10 +163,11 @@ class GenericAPIView(views.APIView):
raise Http404(_("Page is not 'last', nor can it be converted to an int.")) raise Http404(_("Page is not 'last', nor can it be converted to an int."))
try: try:
page = paginator.page(page_number) page = paginator.page(page_number)
except InvalidPage as e: except InvalidPage as exc:
raise Http404(_('Invalid page (%(page_number)s): %(message)s') % { error_format = _('Invalid page (%(page_number)s): %(message)s')
raise Http404(error_format % {
'page_number': page_number, 'page_number': page_number,
'message': str(e) 'message': str(exc)
}) })
if deprecated_style: if deprecated_style:
@ -208,10 +210,8 @@ class GenericAPIView(views.APIView):
return filter_backends return filter_backends
# The following methods provide default implementations
######################## # that you may want to override for more complex cases.
### The following methods provide default implementations
### that you may want to override for more complex cases.
def get_paginate_by(self, queryset=None): def get_paginate_by(self, queryset=None):
""" """
@ -284,8 +284,8 @@ class GenericAPIView(views.APIView):
if self.model is not None: if self.model is not None:
return self.model._default_manager.all() return self.model._default_manager.all()
raise ImproperlyConfigured("'%s' must define 'queryset' or 'model'" error_format = "'%s' must define 'queryset' or 'model'"
% self.__class__.__name__) raise ImproperlyConfigured(error_format % self.__class__.__name__)
def get_object(self, queryset=None): def get_object(self, queryset=None):
""" """
@ -339,12 +339,11 @@ class GenericAPIView(views.APIView):
return obj return obj
######################## # The following are placeholder methods,
### The following are placeholder methods, # and are intended to be overridden.
### and are intended to be overridden. #
### # The are not called by GenericAPIView directly,
### The are not called by GenericAPIView directly, # but are used by the mixin methods.
### but are used by the mixin methods.
def pre_save(self, obj): def pre_save(self, obj):
""" """
@ -416,10 +415,8 @@ class GenericAPIView(views.APIView):
return ret return ret
########################################################## # Concrete view classes that provide method handlers
### Concrete view classes that provide method handlers ### # by composing the mixin classes with the base view.
### by composing the mixin classes with the base view. ###
##########################################################
class CreateAPIView(mixins.CreateModelMixin, class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView): GenericAPIView):
@ -534,9 +531,7 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
return self.destroy(request, *args, **kwargs) return self.destroy(request, *args, **kwargs)
########################## # Deprecated classes
### Deprecated classes ###
##########################
class MultipleObjectAPIView(GenericAPIView): class MultipleObjectAPIView(GenericAPIView):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):

View File

@ -54,8 +54,10 @@ class DefaultContentNegotiation(BaseContentNegotiation):
for media_type in media_type_set: for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type): if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted. # Return the most specific media type as accepted.
if (_MediaType(renderer.media_type).precedence > if (
_MediaType(media_type).precedence): _MediaType(renderer.media_type).precedence >
_MediaType(media_type).precedence
):
# Eg client requests '*/*' # Eg client requests '*/*'
# Accepted media type is 'application/json' # Accepted media type is 'application/json'
return renderer, renderer.media_type return renderer, renderer.media_type

View File

@ -62,9 +62,11 @@ class IsAuthenticatedOrReadOnly(BasePermission):
""" """
def has_permission(self, request, view): def has_permission(self, request, view):
return (request.method in SAFE_METHODS or return (
request.method in SAFE_METHODS or
request.user and request.user and
request.user.is_authenticated()) request.user.is_authenticated()
)
class DjangoModelPermissions(BasePermission): class DjangoModelPermissions(BasePermission):
@ -122,9 +124,11 @@ class DjangoModelPermissions(BasePermission):
perms = self.get_required_permissions(request.method, model_cls) perms = self.get_required_permissions(request.method, model_cls)
return (request.user and return (
request.user and
(request.user.is_authenticated() or not self.authenticated_users_only) and (request.user.is_authenticated() or not self.authenticated_users_only) and
request.user.has_perms(perms)) request.user.has_perms(perms)
)
class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions): class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
@ -212,6 +216,8 @@ class TokenHasReadWriteScope(BasePermission):
required = oauth2_constants.READ if read_only else oauth2_constants.WRITE required = oauth2_constants.READ if read_only else oauth2_constants.WRITE
return oauth2_provider_scope.check(required, request.auth.scope) return oauth2_provider_scope.check(required, request.auth.scope)
assert False, ('TokenHasReadWriteScope requires either the' assert False, (
'TokenHasReadWriteScope requires either the'
'`OAuthAuthentication` or `OAuth2Authentication` authentication ' '`OAuthAuthentication` or `OAuth2Authentication` authentication '
'class to be used.') 'class to be used.'
)

View File

@ -19,8 +19,7 @@ from rest_framework.compat import smart_text
import warnings import warnings
##### Relational fields ##### # Relational fields
# Not actually Writable, but subclasses may need to be. # Not actually Writable, but subclasses may need to be.
class RelatedField(WritableField): class RelatedField(WritableField):
@ -66,7 +65,7 @@ class RelatedField(WritableField):
else: # Reverse else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all() self.queryset = manager.field.rel.to._default_manager.all()
### We need this stuff to make form choices work... # We need this stuff to make form choices work...
def prepare_value(self, obj): def prepare_value(self, obj):
return self.to_native(obj) return self.to_native(obj)
@ -113,7 +112,7 @@ class RelatedField(WritableField):
choices = property(_get_choices, _set_choices) choices = property(_get_choices, _set_choices)
### Default value handling # Default value handling
def get_default_value(self): def get_default_value(self):
default = super(RelatedField, self).get_default_value() default = super(RelatedField, self).get_default_value()
@ -121,7 +120,7 @@ class RelatedField(WritableField):
return [] return []
return default return default
### Regular serializer stuff... # Regular serializer stuff...
def field_to_native(self, obj, field_name): def field_to_native(self, obj, field_name):
try: try:
@ -181,7 +180,7 @@ class RelatedField(WritableField):
into[(self.source or field_name)] = self.from_native(value) into[(self.source or field_name)] = self.from_native(value)
### PrimaryKey relationships # PrimaryKey relationships
class PrimaryKeyRelatedField(RelatedField): class PrimaryKeyRelatedField(RelatedField):
""" """
@ -269,8 +268,7 @@ class PrimaryKeyRelatedField(RelatedField):
return self.to_native(pk) return self.to_native(pk)
### Slug relationships # Slug relationships
class SlugRelatedField(RelatedField): class SlugRelatedField(RelatedField):
""" """
@ -305,7 +303,7 @@ class SlugRelatedField(RelatedField):
raise ValidationError(msg) raise ValidationError(msg)
### Hyperlinked relationships # Hyperlinked relationships
class HyperlinkedRelatedField(RelatedField): class HyperlinkedRelatedField(RelatedField):
""" """

View File

@ -8,7 +8,6 @@ REST framework also provides an HTML renderer the renders the browsable API.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import copy
import json import json
import django import django
from django import forms from django import forms
@ -75,7 +74,6 @@ class JSONRenderer(BaseRenderer):
# E.g. If we're being called by the BrowsableAPIRenderer. # E.g. If we're being called by the BrowsableAPIRenderer.
return renderer_context.get('indent', None) return renderer_context.get('indent', None)
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
""" """
Render `data` into JSON, returning a bytestring. Render `data` into JSON, returning a bytestring.
@ -86,8 +84,10 @@ class JSONRenderer(BaseRenderer):
renderer_context = renderer_context or {} renderer_context = renderer_context or {}
indent = self.get_indent(accepted_media_type, renderer_context) indent = self.get_indent(accepted_media_type, renderer_context)
ret = json.dumps(data, cls=self.encoder_class, ret = json.dumps(
indent=indent, ensure_ascii=self.ensure_ascii) data, cls=self.encoder_class,
indent=indent, ensure_ascii=self.ensure_ascii
)
# On python 2.x json.dumps() returns bytestrings if ensure_ascii=True, # On python 2.x json.dumps() returns bytestrings if ensure_ascii=True,
# but if ensure_ascii=False, the return type is underspecified, # but if ensure_ascii=False, the return type is underspecified,
@ -414,7 +414,7 @@ class BrowsableAPIRenderer(BaseRenderer):
""" """
Returns True if a form should be shown for this method. Returns True if a form should be shown for this method.
""" """
if not method in view.allowed_methods: if method not in view.allowed_methods:
return # Not a valid method return # Not a valid method
if not api_settings.FORM_METHOD_OVERRIDE: if not api_settings.FORM_METHOD_OVERRIDE:
@ -454,8 +454,10 @@ class BrowsableAPIRenderer(BaseRenderer):
if method in ('DELETE', 'OPTIONS'): if method in ('DELETE', 'OPTIONS'):
return True # Don't actually need to return a form return True # Don't actually need to return a form
if (not getattr(view, 'get_serializer', None) if (
or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)): not getattr(view, 'get_serializer', None)
or not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)
):
return return
serializer = view.get_serializer(instance=obj, data=data, files=files) serializer = view.get_serializer(instance=obj, data=data, files=files)
@ -576,7 +578,7 @@ class BrowsableAPIRenderer(BaseRenderer):
'version': VERSION, 'version': VERSION,
'breadcrumblist': self.get_breadcrumbs(request), 'breadcrumblist': self.get_breadcrumbs(request),
'allowed_methods': view.allowed_methods, 'allowed_methods': view.allowed_methods,
'available_formats': [renderer.format for renderer in view.renderer_classes], 'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers, 'response_headers': response_headers,
'put_form': self.get_rendered_html_form(view, 'PUT', request), 'put_form': self.get_rendered_html_form(view, 'PUT', request),
@ -625,4 +627,3 @@ class MultiPartRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
return encode_multipart(self.BOUNDARY, data) return encode_multipart(self.BOUNDARY, data)

View File

@ -295,8 +295,11 @@ class Request(object):
Return the content body of the request, as a stream. Return the content body of the request, as a stream.
""" """
try: try:
content_length = int(self.META.get('CONTENT_LENGTH', content_length = int(
self.META.get('HTTP_CONTENT_LENGTH'))) self.META.get(
'CONTENT_LENGTH', self.META.get('HTTP_CONTENT_LENGTH')
)
)
except (ValueError, TypeError): except (ValueError, TypeError):
content_length = 0 content_length = 0
@ -320,9 +323,11 @@ class Request(object):
) )
# We only need to use form overloading on form POST requests. # We only need to use form overloading on form POST requests.
if (not USE_FORM_OVERLOADING if (
not USE_FORM_OVERLOADING
or self._request.method != 'POST' or self._request.method != 'POST'
or not is_form_media_type(self._content_type)): or not is_form_media_type(self._content_type)
):
return return
# At this point we're committed to parsing the request as form data. # At this point we're committed to parsing the request as form data.
@ -330,15 +335,19 @@ class Request(object):
self._files = self._request.FILES self._files = self._request.FILES
# Method overloading - change the method and remove the param from the content. # Method overloading - change the method and remove the param from the content.
if (self._METHOD_PARAM and if (
self._METHOD_PARAM in self._data): self._METHOD_PARAM and
self._METHOD_PARAM in self._data
):
self._method = self._data[self._METHOD_PARAM].upper() self._method = self._data[self._METHOD_PARAM].upper()
# Content overloading - modify the content type, and force re-parse. # Content overloading - modify the content type, and force re-parse.
if (self._CONTENT_PARAM and if (
self._CONTENT_PARAM and
self._CONTENTTYPE_PARAM and self._CONTENTTYPE_PARAM and
self._CONTENT_PARAM in self._data and self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data): self._CONTENTTYPE_PARAM in self._data
):
self._content_type = self._data[self._CONTENTTYPE_PARAM] self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding'])) self._stream = BytesIO(self._data[self._CONTENT_PARAM].encode(self.parser_context['encoding']))
self._data, self._files = (Empty, Empty) self._data, self._files = (Empty, Empty)
@ -394,7 +403,7 @@ class Request(object):
self._not_authenticated() self._not_authenticated()
raise raise
if not user_auth_tuple is None: if user_auth_tuple is not None:
self._authenticator = authenticator self._authenticator = authenticator
self._user, self._auth = user_auth_tuple self._user, self._auth = user_auth_tuple
return return

View File

@ -62,8 +62,10 @@ class Response(SimpleTemplateResponse):
ret = renderer.render(self.data, media_type, context) ret = renderer.render(self.data, media_type, context)
if isinstance(ret, six.text_type): if isinstance(ret, six.text_type):
assert charset, 'renderer returned unicode, and did not specify ' \ assert charset, (
'renderer returned unicode, and did not specify '
'a charset value.' 'a charset value.'
)
return bytes(ret.encode(charset)) return bytes(ret.encode(charset))
if not ret: if not ret:

View File

@ -449,9 +449,11 @@ class BaseSerializer(WritableField):
# If we have a model manager or similar object then we need # If we have a model manager or similar object then we need
# to iterate through each instance. # to iterate through each instance.
if (self.many and if (
self.many and
not hasattr(obj, '__iter__') and not hasattr(obj, '__iter__') and
is_simple_callable(getattr(obj, 'all', None))): is_simple_callable(getattr(obj, 'all', None))
):
obj = obj.all() obj = obj.all()
kwargs = { kwargs = {
@ -601,8 +603,10 @@ class BaseSerializer(WritableField):
API schemas for auto-documentation. API schemas for auto-documentation.
""" """
return SortedDict( return SortedDict(
[(field_name, field.metadata()) [
for field_name, field in six.iteritems(self.fields)] (field_name, field.metadata())
for field_name, field in six.iteritems(self.fields)
]
) )
@ -656,8 +660,10 @@ class ModelSerializer(Serializer):
""" """
cls = self.opts.model cls = self.opts.model
assert cls is not None, \ assert cls is not None, (
"Serializer class '%s' is missing 'model' Meta option" % self.__class__.__name__ "Serializer class '%s' is missing 'model' Meta option" %
self.__class__.__name__
)
opts = cls._meta.concrete_model._meta opts = cls._meta.concrete_model._meta
ret = SortedDict() ret = SortedDict()
nested = bool(self.opts.depth) nested = bool(self.opts.depth)
@ -668,9 +674,9 @@ class ModelSerializer(Serializer):
# If model is a child via multitable inheritance, use parent's pk # If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk pk_field = pk_field.rel.to._meta.pk
field = self.get_pk_field(pk_field) serializer_pk_field = self.get_pk_field(pk_field)
if field: if serializer_pk_field:
ret[pk_field.name] = field ret[pk_field.name] = serializer_pk_field
# Deal with forward relationships # Deal with forward relationships
forward_rels = [field for field in opts.fields if field.serialize] forward_rels = [field for field in opts.fields if field.serialize]
@ -739,9 +745,11 @@ class ModelSerializer(Serializer):
is_m2m = isinstance(relation.field, is_m2m = isinstance(relation.field,
models.fields.related.ManyToManyField) models.fields.related.ManyToManyField)
if (is_m2m and if (
is_m2m and
hasattr(relation.field.rel, 'through') and hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created): not relation.field.rel.through._meta.auto_created
):
has_through_model = True has_through_model = True
if nested: if nested:
@ -911,10 +919,12 @@ class ModelSerializer(Serializer):
for field_name, field in self.fields.items(): for field_name, field in self.fields.items():
field_name = field.source or field_name field_name = field.source or field_name
if field_name in exclusions \ if (
and not field.read_only \ field_name in exclusions
and (field.required or hasattr(instance, field_name)) \ and not field.read_only
and not isinstance(field, Serializer): and (field.required or hasattr(instance, field_name))
and not isinstance(field, Serializer)
):
exclusions.remove(field_name) exclusions.remove(field_name)
return exclusions return exclusions

View File

@ -46,16 +46,12 @@ DEFAULTS = {
'DEFAULT_PERMISSION_CLASSES': ( 'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny', 'rest_framework.permissions.AllowAny',
), ),
'DEFAULT_THROTTLE_CLASSES': ( 'DEFAULT_THROTTLE_CLASSES': (),
), 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
# Genric view behavior # Genric view behavior
'DEFAULT_MODEL_SERIALIZER_CLASS': 'DEFAULT_MODEL_SERIALIZER_CLASS': 'rest_framework.serializers.ModelSerializer',
'rest_framework.serializers.ModelSerializer', 'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'rest_framework.pagination.PaginationSerializer',
'DEFAULT_PAGINATION_SERIALIZER_CLASS':
'rest_framework.pagination.PaginationSerializer',
'DEFAULT_FILTER_BACKENDS': (), 'DEFAULT_FILTER_BACKENDS': (),
# Throttling # Throttling

View File

@ -10,15 +10,19 @@ from __future__ import unicode_literals
def is_informational(code): def is_informational(code):
return code >= 100 and code <= 199 return code >= 100 and code <= 199
def is_success(code): def is_success(code):
return code >= 200 and code <= 299 return code >= 200 and code <= 299
def is_redirect(code): def is_redirect(code):
return code >= 300 and code <= 399 return code >= 300 and code <= 399
def is_client_error(code): def is_client_error(code):
return code >= 400 and code <= 499 return code >= 400 and code <= 499
def is_server_error(code): def is_server_error(code):
return code >= 500 and code <= 599 return code >= 500 and code <= 599

View File

@ -152,8 +152,10 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
middle = middle[len(opening):] middle = middle[len(opening):]
lead = lead + opening lead = lead + opening
# Keep parentheses at the end only if they're balanced. # Keep parentheses at the end only if they're balanced.
if (middle.endswith(closing) if (
and middle.count(closing) == middle.count(opening) + 1): middle.endswith(closing)
and middle.count(closing) == middle.count(opening) + 1
):
middle = middle[:-len(closing)] middle = middle[:-len(closing)]
trail = closing + trail trail = closing + trail
@ -164,7 +166,7 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
url = smart_urlquote_wrapper(middle) url = smart_urlquote_wrapper(middle)
elif simple_url_2_re.match(middle): elif simple_url_2_re.match(middle):
url = smart_urlquote_wrapper('http://%s' % middle) url = smart_urlquote_wrapper('http://%s' % middle)
elif not ':' in middle and simple_email_re.match(middle): elif ':' not in middle and simple_email_re.match(middle):
local, domain = middle.rsplit('@', 1) local, domain = middle.rsplit('@', 1)
try: try:
domain = domain.encode('idna').decode('ascii') domain = domain.encode('idna').decode('ascii')

View File

@ -49,9 +49,10 @@ class APIRequestFactory(DjangoRequestFactory):
else: else:
format = format or self.default_format format = format or self.default_format
assert format in self.renderer_classes, ("Invalid format '{0}'. " assert format in self.renderer_classes, (
"Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES " "Invalid format '{0}'. Available formats are {1}. "
"to enable extra request formats.".format( "Set TEST_REQUEST_RENDERER_CLASSES to enable "
"extra request formats.".format(
format, format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()])
) )

View File

@ -14,11 +14,13 @@ your authentication settings include `SessionAuthentication`.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import patterns, url from django.conf.urls import patterns, url
from django.contrib.auth import views
template_name = {'template_name': 'rest_framework/login.html'} template_name = {'template_name': 'rest_framework/login.html'}
urlpatterns = patterns('django.contrib.auth.views', urlpatterns = patterns(
url(r'^login/$', 'login', template_name, name='login'), '',
url(r'^logout/$', 'logout', template_name, name='logout'), url(r'^login/$', views.login, template_name, name='login'),
url(r'^logout/$', views.logout, template_name, name='logout')
) )

View File

@ -98,14 +98,23 @@ else:
node.flow_style = best_style node.flow_style = best_style
return node return node
SafeDumper.add_representer(decimal.Decimal, SafeDumper.add_representer(
SafeDumper.represent_decimal) decimal.Decimal,
SafeDumper.represent_decimal
SafeDumper.add_representer(SortedDict, )
yaml.representer.SafeRepresenter.represent_dict) SafeDumper.add_representer(
SafeDumper.add_representer(DictWithMetadata, SortedDict,
yaml.representer.SafeRepresenter.represent_dict) yaml.representer.SafeRepresenter.represent_dict
SafeDumper.add_representer(SortedDictWithMetadata, )
yaml.representer.SafeRepresenter.represent_dict) SafeDumper.add_representer(
SafeDumper.add_representer(types.GeneratorType, DictWithMetadata,
yaml.representer.SafeRepresenter.represent_list) yaml.representer.SafeRepresenter.represent_dict
)
SafeDumper.add_representer(
SortedDictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict
)
SafeDumper.add_representer(
types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list
)

View File

@ -6,8 +6,6 @@ from __future__ import unicode_literals
from django.utils.html import escape from django.utils.html import escape
from django.utils.safestring import mark_safe from django.utils.safestring import mark_safe
from rest_framework.compat import apply_markdown from rest_framework.compat import apply_markdown
from rest_framework.settings import api_settings
from textwrap import dedent
import re import re
@ -40,6 +38,7 @@ def dedent(content):
return content.strip() return content.strip()
def camelcase_to_spaces(content): def camelcase_to_spaces(content):
""" """
Translate 'CamelCaseNames' to 'Camel Case Names'. Translate 'CamelCaseNames' to 'Camel Case Names'.
@ -49,6 +48,7 @@ def camelcase_to_spaces(content):
content = re.sub(camelcase_boundry, ' \\1', content).strip() content = re.sub(camelcase_boundry, ' \\1', content).strip()
return ' '.join(content.split('_')).title() return ' '.join(content.split('_')).title()
def markup_description(description): def markup_description(description):
""" """
Apply HTML markup to the given description. Apply HTML markup to the given description.

View File

@ -79,7 +79,7 @@ class _MediaType(object):
return 3 return 3
def __str__(self): def __str__(self):
return unicode(self).encode('utf-8') return self.__unicode__().encode('utf-8')
def __unicode__(self): def __unicode__(self):
ret = "%s/%s" % (self.main_type, self.sub_type) ret = "%s/%s" % (self.main_type, self.sub_type)

View File

@ -31,6 +31,7 @@ def get_view_name(view_cls, suffix=None):
return name return name
def get_view_description(view_cls, html=False): def get_view_description(view_cls, html=False):
""" """
Given a view class, return a textual description to represent the view. Given a view class, return a textual description to represent the view.
@ -119,7 +120,6 @@ class APIView(View):
headers['Vary'] = 'Accept' headers['Vary'] = 'Accept'
return headers return headers
def http_method_not_allowed(self, request, *args, **kwargs): def http_method_not_allowed(self, request, *args, **kwargs):
""" """
If `request.method` does not correspond to a handler method, If `request.method` does not correspond to a handler method,

86
runtests.py Executable file
View File

@ -0,0 +1,86 @@
#! /usr/bin/env python
from __future__ import print_function
import pytest
import sys
import os
import subprocess
PYTEST_ARGS = {
'default': ['tests'],
'fast': ['tests', '-q'],
}
FLAKE8_ARGS = ['rest_framework', 'tests', '--ignore=E501']
sys.path.append(os.path.dirname(__file__))
def exit_on_failure(ret, message=None):
if ret:
sys.exit(ret)
def flake8_main(args):
print('Running flake8 code linting')
ret = subprocess.call(['flake8'] + args)
print('flake8 failed' if ret else 'flake8 passed')
return ret
def split_class_and_function(string):
class_string, function_string = string.split('.', 1)
return "%s and %s" % (class_string, function_string)
def is_function(string):
# `True` if it looks like a test function is included in the string.
return string.startswith('test_') or '.test_' in string
def is_class(string):
# `True` if first character is uppercase - assume it's a class name.
return string[0] == string[0].upper()
if __name__ == "__main__":
try:
sys.argv.remove('--nolint')
except ValueError:
run_flake8 = True
else:
run_flake8 = False
try:
sys.argv.remove('--lintonly')
except ValueError:
run_tests = True
else:
run_tests = False
try:
sys.argv.remove('--fast')
except ValueError:
style = 'default'
else:
style = 'fast'
run_flake8 = False
if len(sys.argv) > 1:
pytest_args = sys.argv[1:]
first_arg = pytest_args[0]
if first_arg.startswith('-'):
# `runtests.py [flags]`
pytest_args = ['tests'] + pytest_args
elif is_class(first_arg) and is_function(first_arg):
# `runtests.py TestCase.test_function [flags]`
expression = split_class_and_function(first_arg)
pytest_args = ['tests', '-k', expression] + pytest_args[1:]
elif is_class(first_arg) or is_function(first_arg):
# `runtests.py TestCase [flags]`
# `runtests.py test_function [flags]`
pytest_args = ['tests', '-k', pytest_args[0]] + pytest_args[1:]
else:
pytest_args = PYTEST_ARGS[style]
if run_tests:
exit_on_failure(pytest.main(pytest_args))
if run_flake8:
exit_on_failure(flake8_main(FLAKE8_ARGS))

View File

@ -47,8 +47,8 @@ def pytest_configure():
) )
try: try:
import oauth_provider import oauth_provider # NOQA
import oauth2 import oauth2 # NOQA
except ImportError: except ImportError:
pass pass
else: else:
@ -57,7 +57,7 @@ def pytest_configure():
) )
try: try:
import provider import provider # NOQA
except ImportError: except ImportError:
pass pass
else: else:
@ -68,13 +68,13 @@ def pytest_configure():
# guardian is optional # guardian is optional
try: try:
import guardian import guardian # NOQA
except ImportError: except ImportError:
pass pass
else: else:
settings.ANONYMOUS_USER_ID = -1 settings.ANONYMOUS_USER_ID = -1
settings.AUTHENTICATION_BACKENDS = ( settings.AUTHENTICATION_BACKENDS = (
'django.contrib.auth.backends.ModelBackend', # default 'django.contrib.auth.backends.ModelBackend',
'guardian.backends.ObjectPermissionBackend', 'guardian.backends.ObjectPermissionBackend',
) )
settings.INSTALLED_APPS += ( settings.INSTALLED_APPS += (

View File

@ -1,5 +1,4 @@
from rest_framework import serializers from rest_framework import serializers
from tests.models import NullableForeignKeySource from tests.models import NullableForeignKeySource

View File

@ -68,7 +68,6 @@ SECRET_KEY = 'u@x-aj9(hoh#rb-^ymf#g2jx_hp0vj7u5#b@ag1n^seu9e!%cy'
TEMPLATE_LOADERS = ( TEMPLATE_LOADERS = (
'django.template.loaders.filesystem.Loader', 'django.template.loaders.filesystem.Loader',
'django.template.loaders.app_directories.Loader', 'django.template.loaders.app_directories.Loader',
# 'django.template.loaders.eggs.Loader',
) )
MIDDLEWARE_CLASSES = ( MIDDLEWARE_CLASSES = (
@ -104,8 +103,8 @@ INSTALLED_APPS = (
# OAuth is optional and won't work if there is no oauth_provider & oauth2 # OAuth is optional and won't work if there is no oauth_provider & oauth2
try: try:
import oauth_provider import oauth_provider # NOQA
import oauth2 import oauth2 # NOQA
except ImportError: except ImportError:
pass pass
else: else:
@ -114,7 +113,7 @@ else:
) )
try: try:
import provider import provider # NOQA
except ImportError: except ImportError:
pass pass
else: else:
@ -125,7 +124,7 @@ else:
# guardian is optional # guardian is optional
try: try:
import guardian import guardian # NOQA
except ImportError: except ImportError:
pass pass
else: else:

View File

@ -45,26 +45,39 @@ class MockView(APIView):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({'a': 1, 'b': 2, 'c': 3})
urlpatterns = patterns('', urlpatterns = patterns(
'',
(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
(r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])), (r'^oauth/$', MockView.as_view(authentication_classes=[OAuthAuthentication])),
(r'^oauth-with-scope/$', MockView.as_view(authentication_classes=[OAuthAuthentication], (
permission_classes=[permissions.TokenHasReadWriteScope])) r'^oauth-with-scope/$',
MockView.as_view(
authentication_classes=[OAuthAuthentication],
permission_classes=[permissions.TokenHasReadWriteScope]
) )
)
)
class OAuth2AuthenticationDebug(OAuth2Authentication): class OAuth2AuthenticationDebug(OAuth2Authentication):
allow_query_params_token = True allow_query_params_token = True
if oauth2_provider is not None: if oauth2_provider is not None:
urlpatterns += patterns('', urlpatterns += patterns(
'',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], url(
permission_classes=[permissions.TokenHasReadWriteScope])), r'^oauth2-with-scope-test/$',
MockView.as_view(
authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope]
)
)
) )
@ -278,12 +291,16 @@ class OAuthTests(TestCase):
self.TOKEN_KEY = "token_key" self.TOKEN_KEY = "token_key"
self.TOKEN_SECRET = "token_secret" self.TOKEN_SECRET = "token_secret"
self.consumer = Consumer.objects.create(key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET, self.consumer = Consumer.objects.create(
name='example', user=self.user, status=self.consts.ACCEPTED) key=self.CONSUMER_KEY, secret=self.CONSUMER_SECRET,
name='example', user=self.user, status=self.consts.ACCEPTED
)
self.scope = Scope.objects.create(name="resource name", url="api/") self.scope = Scope.objects.create(name="resource name", url="api/")
self.token = OAuthToken.objects.create(user=self.user, consumer=self.consumer, scope=self.scope, self.token = OAuthToken.objects.create(
token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET, is_approved=True user=self.user, consumer=self.consumer, scope=self.scope,
token_type=OAuthToken.ACCESS, key=self.TOKEN_KEY, secret=self.TOKEN_SECRET,
is_approved=True
) )
def _create_authorization_header(self): def _create_authorization_header(self):
@ -569,8 +586,10 @@ class OAuth2Tests(TestCase):
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth_url_transport(self): def test_post_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in form data succeed""" """Ensure GETing form over OAuth with correct client credentials in form data succeed"""
response = self.csrf_client.post('/oauth2-test/', response = self.csrf_client.post(
data={'access_token': self.access_token.token}) '/oauth2-test/',
data={'access_token': self.access_token.token}
)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')

View File

@ -24,7 +24,8 @@ class NestedResourceRoot(APIView):
class NestedResourceInstance(APIView): class NestedResourceInstance(APIView):
pass pass
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^$', Root.as_view()), url(r'^$', Root.as_view()),
url(r'^resource/$', ResourceRoot.as_view()), url(r'^resource/$', ResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
@ -40,34 +41,60 @@ class BreadcrumbTests(TestCase):
def test_root_breadcrumbs(self): def test_root_breadcrumbs(self):
url = '/' url = '/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) self.assertEqual(
get_breadcrumbs(url),
[('Root', '/')]
)
def test_resource_root_breadcrumbs(self): def test_resource_root_breadcrumbs(self):
url = '/resource/' url = '/resource/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(
('Resource Root', '/resource/')]) get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/')
]
)
def test_resource_instance_breadcrumbs(self): def test_resource_instance_breadcrumbs(self):
url = '/resource/123' url = '/resource/123'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(
get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root', '/resource/'),
('Resource Instance', '/resource/123')]) ('Resource Instance', '/resource/123')
]
)
def test_nested_resource_breadcrumbs(self): def test_nested_resource_breadcrumbs(self):
url = '/resource/123/' url = '/resource/123/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(
get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'), ('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/')]) ('Nested Resource Root', '/resource/123/')
]
)
def test_nested_resource_instance_breadcrumbs(self): def test_nested_resource_instance_breadcrumbs(self):
url = '/resource/123/abc' url = '/resource/123/abc'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(
get_breadcrumbs(url),
[
('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root', '/resource/'),
('Resource Instance', '/resource/123'), ('Resource Instance', '/resource/123'),
('Nested Resource Root', '/resource/123/'), ('Nested Resource Root', '/resource/123/'),
('Nested Resource Instance', '/resource/123/abc')]) ('Nested Resource Instance', '/resource/123/abc')
]
)
def test_broken_url_breadcrumbs_handled_gracefully(self): def test_broken_url_breadcrumbs_handled_gracefully(self):
url = '/foobar' url = '/foobar'
self.assertEqual(get_breadcrumbs(url), [('Root', '/')]) self.assertEqual(
get_breadcrumbs(url),
[('Root', '/')]
)

View File

@ -85,11 +85,8 @@ class FileSerializerTests(TestCase):
""" """
Validation should still function when no data dictionary is provided. Validation should still function when no data dictionary is provided.
""" """
now = datetime.datetime.now() uploaded_file = BytesIO(six.b('stuff'))
file = BytesIO(six.b('stuff')) uploaded_file.name = 'stuff.txt'
file.name = 'stuff.txt' uploaded_file.size = len(uploaded_file.getvalue())
file.size = len(file.getvalue()) serializer = UploadedFileSerializer(files={'file': uploaded_file})
uploaded_file = UploadedFile(file=file, created=now)
serializer = UploadedFileSerializer(files={'file': file})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())

View File

@ -74,7 +74,8 @@ if django_filters:
def get_queryset(self): def get_queryset(self):
return FilterableItem.objects.all() return FilterableItem.objects.all()
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
url(r'^$', FilterClassRootView.as_view(), name='root-view'), url(r'^$', FilterClassRootView.as_view(), name='root-view'),
url(r'^get-queryset/$', GetQuerysetView.as_view(), url(r'^get-queryset/$', GetQuerysetView.as_view(),

View File

@ -34,7 +34,8 @@ def not_found(request):
raise Http404() raise Http404()
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^$', example), url(r'^$', example),
url(r'^permission_denied$', permission_denied), url(r'^permission_denied$', permission_denied),
url(r'^not_found$', not_found), url(r'^not_found$', not_found),

View File

@ -94,7 +94,8 @@ class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
model_serializer_class = serializers.HyperlinkedModelSerializer model_serializer_class = serializers.HyperlinkedModelSerializer
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'),
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'),

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime import datetime
from decimal import Decimal from decimal import Decimal
from django.db import models
from django.core.paginator import Paginator from django.core.paginator import Paginator
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
@ -12,6 +11,7 @@ from .models import BasicModel, FilterableItem
factory = APIRequestFactory() factory = APIRequestFactory()
# Helper function to split arguments out of an url # Helper function to split arguments out of an url
def split_arguments_from_url(url): def split_arguments_from_url(url):
if '?' not in url: if '?' not in url:
@ -363,11 +363,11 @@ class TestMaxPaginateByParam(TestCase):
self.assertEqual(response.data['results'], self.data[:3]) self.assertEqual(response.data['results'], self.data[:3])
### Tests for context in pagination serializers # Tests for context in pagination serializers
class CustomField(serializers.Field): class CustomField(serializers.Field):
def to_native(self, value): def to_native(self, value):
if not 'view' in self.context: if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into custom field") raise RuntimeError("context isn't getting passed into custom field")
return "value" return "value"
@ -377,7 +377,7 @@ class BasicModelSerializer(serializers.Serializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(BasicModelSerializer, self).__init__(*args, **kwargs) super(BasicModelSerializer, self).__init__(*args, **kwargs)
if not 'view' in self.context: if 'view' not in self.context:
raise RuntimeError("context isn't getting passed into serializer init") raise RuntimeError("context isn't getting passed into serializer init")
@ -398,7 +398,7 @@ class TestContextPassedToCustomField(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
### Tests for custom pagination serializers # Tests for custom pagination serializers
class LinksSerializer(serializers.Serializer): class LinksSerializer(serializers.Serializer):
next = pagination.NextPageField(source='*') next = pagination.NextPageField(source='*')
@ -483,8 +483,6 @@ class NonIntegerPaginator(object):
class TestNonIntegerPagination(TestCase): class TestNonIntegerPagination(TestCase):
def test_custom_pagination_serializer(self): def test_custom_pagination_serializer(self):
objects = ['john', 'paul', 'george', 'ringo'] objects = ['john', 'paul', 'george', 'ringo']
paginator = NonIntegerPaginator(objects, 2) paginator = NonIntegerPaginator(objects, 2)

View File

@ -12,6 +12,7 @@ import base64
factory = APIRequestFactory() factory = APIRequestFactory()
class RootView(generics.ListCreateAPIView): class RootView(generics.ListCreateAPIView):
model = BasicModel model = BasicModel
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
@ -101,42 +102,54 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_options_permitted(self): def test_options_permitted(self):
request = factory.options('/', request = factory.options(
HTTP_AUTHORIZATION=self.permitted_credentials) '/',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['POST']) self.assertEqual(list(response.data['actions'].keys()), ['POST'])
request = factory.options('/1', request = factory.options(
HTTP_AUTHORIZATION=self.permitted_credentials) '/1',
HTTP_AUTHORIZATION=self.permitted_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) self.assertIn('actions', response.data)
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_options_disallowed(self): def test_options_disallowed(self):
request = factory.options('/', request = factory.options(
HTTP_AUTHORIZATION=self.disallowed_credentials) '/',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
request = factory.options('/1', request = factory.options(
HTTP_AUTHORIZATION=self.disallowed_credentials) '/1',
HTTP_AUTHORIZATION=self.disallowed_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
def test_options_updateonly(self): def test_options_updateonly(self):
request = factory.options('/', request = factory.options(
HTTP_AUTHORIZATION=self.updateonly_credentials) '/',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = root_view(request, pk='1') response = root_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotIn('actions', response.data) self.assertNotIn('actions', response.data)
request = factory.options('/1', request = factory.options(
HTTP_AUTHORIZATION=self.updateonly_credentials) '/1',
HTTP_AUTHORIZATION=self.updateonly_credentials
)
response = instance_view(request, pk='1') response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn('actions', response.data) self.assertIn('actions', response.data)
@ -153,6 +166,7 @@ class BasicPermModel(models.Model):
# add, change, delete built in to django # add, change, delete built in to django
) )
# Custom object-level permission, that includes 'view' permissions # Custom object-level permission, that includes 'view' permissions
class ViewObjectPermissions(permissions.DjangoObjectPermissions): class ViewObjectPermissions(permissions.DjangoObjectPermissions):
perms_map = { perms_map = {
@ -246,21 +260,27 @@ class ObjectPermissionsIntegrationTests(TestCase):
# Update # Update
def test_can_update_permissions(self): def test_can_update_permissions(self):
request = factory.patch('/1', {'text': 'foobar'}, format='json', request = factory.patch(
HTTP_AUTHORIZATION=self.credentials['writeonly']) '/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['writeonly']
)
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data.get('text'), 'foobar') self.assertEqual(response.data.get('text'), 'foobar')
def test_cannot_update_permissions(self): def test_cannot_update_permissions(self):
request = factory.patch('/1', {'text': 'foobar'}, format='json', request = factory.patch(
HTTP_AUTHORIZATION=self.credentials['deleteonly']) '/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_cannot_update_permissions_non_existing(self): def test_cannot_update_permissions_non_existing(self):
request = factory.patch('/999', {'text': 'foobar'}, format='json', request = factory.patch(
HTTP_AUTHORIZATION=self.credentials['deleteonly']) '/999', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.credentials['deleteonly']
)
response = object_permissions_view(request, pk='999') response = object_permissions_view(request, pk='999')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

View File

@ -108,19 +108,25 @@ class RelatedFieldSourceTests(TestCase):
doesn't exist. doesn't exist.
""" """
from tests.models import ManyToManySource from tests.models import ManyToManySource
class Meta: class Meta:
model = ManyToManySource model = ManyToManySource
attrs = { attrs = {
'name': serializers.SlugRelatedField( 'name': serializers.SlugRelatedField(
slug_field='name', source='banzai'), slug_field='name', source='banzai'),
'Meta': Meta, 'Meta': Meta,
} }
TestSerializer = type(str('TestSerializer'), TestSerializer = type(
(serializers.ModelSerializer,), attrs) str('TestSerializer'),
(serializers.ModelSerializer,),
attrs
)
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
TestSerializer(data={'name': 'foo'}) TestSerializer(data={'name': 'foo'})
@unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6') @unittest.skipIf(get_version() < '1.6.0', 'Upstream behaviour changed in v1.6')
class RelatedFieldChoicesTests(TestCase): class RelatedFieldChoicesTests(TestCase):
""" """
@ -141,4 +147,3 @@ class RelatedFieldChoicesTests(TestCase):
widget_count = len(field.widget.choices) widget_count = len(field.widget.choices)
self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added') self.assertEqual(widget_count, choice_count + 1, 'BLANK_CHOICE_DASH option should have been added')

View File

@ -16,7 +16,8 @@ request = factory.get('/') # Just to ensure we have a request in the serializer
def dummy_view(request, pk): def dummy_view(request, pk):
pass pass
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),

View File

@ -76,7 +76,6 @@ class MockGETView(APIView):
return Response({'foo': ['bar', 'baz']}) return Response({'foo': ['bar', 'baz']})
class MockPOSTView(APIView): class MockPOSTView(APIView):
def post(self, request, **kwargs): def post(self, request, **kwargs):
return Response({'foo': request.DATA}) return Response({'foo': request.DATA})
@ -102,7 +101,8 @@ class HTMLView1(APIView):
def get(self, request, **kwargs): def get(self, request, **kwargs):
return Response('text') return Response('text')
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^cache$', MockGETView.as_view()), url(r'^cache$', MockGETView.as_view()),
@ -312,16 +312,22 @@ class JSONRendererTests(TestCase):
class Dict(MutableMapping): class Dict(MutableMapping):
def __init__(self): def __init__(self):
self._dict = dict() self._dict = dict()
def __getitem__(self, key): def __getitem__(self, key):
return self._dict.__getitem__(key) return self._dict.__getitem__(key)
def __setitem__(self, key, value): def __setitem__(self, key, value):
return self._dict.__setitem__(key, value) return self._dict.__setitem__(key, value)
def __delitem__(self, key): def __delitem__(self, key):
return self._dict.__delitem__(key) return self._dict.__delitem__(key)
def __iter__(self): def __iter__(self):
return self._dict.__iter__() return self._dict.__iter__()
def __len__(self): def __len__(self):
return self._dict.__len__() return self._dict.__len__()
def keys(self): def keys(self):
return self._dict.keys() return self._dict.keys()
@ -336,8 +342,10 @@ class JSONRendererTests(TestCase):
class DictLike(object): class DictLike(object):
def __init__(self): def __init__(self):
self._dict = {} self._dict = {}
def set(self, value): def set(self, value):
self._dict = dict(value) self._dict = dict(value)
def __getitem__(self, key): def __getitem__(self, key):
return self._dict[key] return self._dict[key]
@ -394,35 +402,47 @@ class JSONPRendererTests(TestCase):
""" """
Test JSONP rendering with View JSON Renderer. Test JSONP rendering with View JSON Renderer.
""" """
resp = self.client.get('/jsonp/jsonrenderer', resp = self.client.get(
HTTP_ACCEPT='application/javascript') '/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content, self.assertEqual(
('callback(%s);' % _flat_repr).encode('ascii')) resp.content,
('callback(%s);' % _flat_repr).encode('ascii')
)
def test_without_callback_without_json_renderer(self): def test_without_callback_without_json_renderer(self):
""" """
Test JSONP rendering without View JSON Renderer. Test JSONP rendering without View JSON Renderer.
""" """
resp = self.client.get('/jsonp/nojsonrenderer', resp = self.client.get(
HTTP_ACCEPT='application/javascript') '/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content, self.assertEqual(
('callback(%s);' % _flat_repr).encode('ascii')) resp.content,
('callback(%s);' % _flat_repr).encode('ascii')
)
def test_with_callback(self): def test_with_callback(self):
""" """
Test JSONP rendering with callback function name. Test JSONP rendering with callback function name.
""" """
callback_func = 'myjsonpcallback' callback_func = 'myjsonpcallback'
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, resp = self.client.get(
HTTP_ACCEPT='application/javascript') '/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript'
)
self.assertEqual(resp.status_code, status.HTTP_200_OK) self.assertEqual(resp.status_code, status.HTTP_200_OK)
self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8') self.assertEqual(resp['Content-Type'], 'application/javascript; charset=utf-8')
self.assertEqual(resp.content, self.assertEqual(
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')) resp.content,
('%s(%s);' % (callback_func, _flat_repr)).encode('ascii')
)
if yaml: if yaml:
@ -467,7 +487,6 @@ if yaml:
def assertYAMLContains(self, content, string): def assertYAMLContains(self, content, string):
self.assertTrue(string in content, '%r not in %r' % (string, content)) self.assertTrue(string in content, '%r not in %r' % (string, content))
class UnicodeYAMLRendererTests(TestCase): class UnicodeYAMLRendererTests(TestCase):
""" """
Tests specific for the Unicode YAML Renderer Tests specific for the Unicode YAML Renderer
@ -592,13 +611,13 @@ class CacheRenderTest(TestCase):
""" Return any errors that would be raised if `obj' is pickled """ Return any errors that would be raised if `obj' is pickled
Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897 Courtesy of koffie @ http://stackoverflow.com/a/7218986/109897
""" """
if seen == None: if seen is None:
seen = [] seen = []
try: try:
state = obj.__getstate__() state = obj.__getstate__()
except AttributeError: except AttributeError:
return return
if state == None: if state is None:
return return
if isinstance(state, tuple): if isinstance(state, tuple):
if not isinstance(state[0], dict): if not isinstance(state[0], dict):

View File

@ -272,7 +272,8 @@ class MockView(APIView):
return Response(status=status.INTERNAL_SERVER_ERROR) return Response(status=status.INTERNAL_SERVER_ERROR)
urlpatterns = patterns('', urlpatterns = patterns(
'',
(r'^$', MockView.as_view()), (r'^$', MockView.as_view()),
) )

View File

@ -100,7 +100,8 @@ new_model_viewset_router = routers.DefaultRouter()
new_model_viewset_router.register(r'', HTMLNewModelViewSet) new_model_viewset_router.register(r'', HTMLNewModelViewSet)
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),

View File

@ -10,7 +10,8 @@ factory = APIRequestFactory()
def null_view(request): def null_view(request):
pass pass
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^view$', null_view, name='view'), url(r'^view$', null_view, name='view'),
) )

View File

@ -93,7 +93,8 @@ class TestCustomLookupFields(TestCase):
from tests import test_routers from tests import test_routers
urls = getattr(test_routers, 'urlpatterns') urls = getattr(test_routers, 'urlpatterns')
urls += patterns('', urls += patterns(
'',
url(r'^', include(self.router.urls)), url(r'^', include(self.router.urls)),
) )
@ -104,7 +105,8 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_list_view(self): def test_retrieve_lookup_field_list_view(self):
response = self.client.get('/notes/') response = self.client.get('/notes/')
self.assertEqual(response.data, self.assertEqual(
response.data,
[{ [{
"url": "http://testserver/notes/123/", "url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar" "uuid": "123", "text": "foo bar"
@ -113,7 +115,8 @@ class TestCustomLookupFields(TestCase):
def test_retrieve_lookup_field_detail_view(self): def test_retrieve_lookup_field_detail_view(self):
response = self.client.get('/notes/123/') response = self.client.get('/notes/123/')
self.assertEqual(response.data, self.assertEqual(
response.data,
{ {
"url": "http://testserver/notes/123/", "url": "http://testserver/notes/123/",
"uuid": "123", "text": "foo bar" "uuid": "123", "text": "foo bar"

View File

@ -7,10 +7,12 @@ from django.utils import unittest
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import serializers, fields, relations from rest_framework import serializers, fields, relations
from tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, from tests.models import (
BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel, DefaultValueModel, HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo, RESTFrameworkModel, BlankFieldModel, BlogPost, BlogPostComment, Book, CallableDefaultValueModel,
ForeignKeySource, ManyToManySource) DefaultValueModel, ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo,
RESTFrameworkModel, ForeignKeySource
)
from tests.models import BasicModelSerializer from tests.models import BasicModelSerializer
import datetime import datetime
import pickle import pickle
@ -99,6 +101,7 @@ class ActionItemSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ActionItem model = ActionItem
class ActionItemSerializerOptionalFields(serializers.ModelSerializer): class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
""" """
Intended to test that fields with `required=False` are excluded from validation. Intended to test that fields with `required=False` are excluded from validation.
@ -109,6 +112,7 @@ class ActionItemSerializerOptionalFields(serializers.ModelSerializer):
model = ActionItem model = ActionItem
fields = ('title',) fields = ('title',)
class ActionItemSerializerCustomRestore(serializers.ModelSerializer): class ActionItemSerializerCustomRestore(serializers.ModelSerializer):
class Meta: class Meta:
@ -295,8 +299,10 @@ class BasicTests(TestCase):
in the Meta data in the Meta data
""" """
serializer = PersonSerializer(self.person) serializer = PersonSerializer(self.person)
self.assertEqual(set(serializer.data.keys()), self.assertEqual(
set(['name', 'age', 'info'])) set(serializer.data.keys()),
set(['name', 'age', 'info'])
)
def test_field_with_dictionary(self): def test_field_with_dictionary(self):
""" """
@ -660,7 +666,7 @@ class ModelValidationTests(TestCase):
serializer.save() serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'}) second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid()) self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.'],}) self.assertEqual(second_serializer.errors, {'title': ['Album with this Title already exists.']})
third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}], many=True) third_serializer = AlbumsSerializer(data=[{'title': 'b', 'ref': '1'}, {'title': 'c'}], many=True)
self.assertFalse(third_serializer.is_valid()) self.assertFalse(third_serializer.is_valid())
self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}]) self.assertEqual(third_serializer.errors, [{'ref': ['Album with this Ref already exists.']}, {}])
@ -1276,7 +1282,7 @@ class BlankFieldTests(TestCase):
self.fail('Exception raised on save() after validation passes') self.fail('Exception raised on save() after validation passes')
#test for issue #460 # Test for issue #460
class SerializerPickleTests(TestCase): class SerializerPickleTests(TestCase):
""" """
Test pickleability of the output of Serializers Test pickleability of the output of Serializers
@ -1500,7 +1506,7 @@ class NestedSerializerContextTests(TestCase):
callable = serializers.SerializerMethodField('_callable') callable = serializers.SerializerMethodField('_callable')
def _callable(self, instance): def _callable(self, instance):
if not 'context_item' in self.context: if 'context_item' not in self.context:
raise RuntimeError("context isn't getting passed into 2nd level nested serializer") raise RuntimeError("context isn't getting passed into 2nd level nested serializer")
return "success" return "success"
@ -1513,7 +1519,7 @@ class NestedSerializerContextTests(TestCase):
callable = serializers.SerializerMethodField("_callable") callable = serializers.SerializerMethodField("_callable")
def _callable(self, instance): def _callable(self, instance):
if not 'context_item' in self.context: if 'context_item' not in self.context:
raise RuntimeError("context isn't getting passed into 1st level nested serializer") raise RuntimeError("context isn't getting passed into 1st level nested serializer")
return "success" return "success"
@ -1816,7 +1822,7 @@ class MetadataSerializerTestCase(TestCase):
self.assertEqual(expected, metadata) self.assertEqual(expected, metadata)
### Regression test for #840 # Regression test for #840
class SimpleModel(models.Model): class SimpleModel(models.Model):
text = models.CharField(max_length=100) text = models.CharField(max_length=100)
@ -1850,7 +1856,7 @@ class FieldValidationRemovingAttr(TestCase):
self.assertEqual(serializer.object.text, 'foo') self.assertEqual(serializer.object.text, 'foo')
### Regression test for #878 # Regression test for #878
class SimpleTargetModel(models.Model): class SimpleTargetModel(models.Model):
text = models.CharField(max_length=100) text = models.CharField(max_length=100)

View File

@ -328,12 +328,14 @@ class NestedModelSerializerUpdateTests(TestCase):
class BlogPostSerializer(serializers.ModelSerializer): class BlogPostSerializer(serializers.ModelSerializer):
comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set') comments = BlogPostCommentSerializer(many=True, source='blogpostcomment_set')
class Meta: class Meta:
model = models.BlogPost model = models.BlogPost
fields = ('id', 'title', 'comments') fields = ('id', 'title', 'comments')
class PersonSerializer(serializers.ModelSerializer): class PersonSerializer(serializers.ModelSerializer):
posts = BlogPostSerializer(many=True, source='blogpost_set') posts = BlogPostSerializer(many=True, source='blogpost_set')
class Meta: class Meta:
model = models.Person model = models.Person
fields = ('id', 'name', 'age', 'posts') fields = ('id', 'name', 'age', 'posts')

View File

@ -1,9 +1,7 @@
from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework.compat import six
from rest_framework.serializers import _resolve_model from rest_framework.serializers import _resolve_model
from tests.models import BasicModel from tests.models import BasicModel
from rest_framework.compat import six
class ResolveModelTests(TestCase): class ResolveModelTests(TestCase):

View File

@ -11,7 +11,7 @@ class TemplateTagTests(TestCase):
def test_add_query_param_with_non_latin_charactor(self): def test_add_query_param_with_non_latin_charactor(self):
# Ensure we don't double-escape non-latin characters # Ensure we don't double-escape non-latin characters
# that are present in the querystring. # that are present in the querystring.
# See #1314. # See #1314.
request = factory.get("/", {'q': '查询'}) request = factory.get("/", {'q': '查询'})
json_url = add_query_param(request, "format", "json") json_url = add_query_param(request, "format", "json")
@ -48,4 +48,4 @@ class Issue1386Tests(TestCase):
self.assertEqual(i, res) self.assertEqual(i, res)
# example from issue #1386, this shouldn't raise an exception # example from issue #1386, this shouldn't raise an exception
_ = urlize_quoted_links("asdf:[/p]zxcv.com") urlize_quoted_links("asdf:[/p]zxcv.com")

View File

@ -28,7 +28,8 @@ def session_view(request):
}) })
urlpatterns = patterns('', urlpatterns = patterns(
'',
url(r'^view/$', view), url(r'^view/$', view),
url(r'^session-view/$', session_view), url(r'^session-view/$', session_view),
) )
@ -142,7 +143,8 @@ class TestAPIRequestFactory(TestCase):
assertion error. assertion error.
""" """
factory = APIRequestFactory() factory = APIRequestFactory()
self.assertRaises(AssertionError, factory.post, self.assertRaises(
AssertionError, factory.post,
path='/view/', data={'example': 1}, format='xml' path='/view/', data={'example': 1}, format='xml'
) )

View File

@ -125,36 +125,42 @@ class ThrottlingTests(TestCase):
""" """
Ensure for second based throttles. Ensure for second based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView, self.ensure_response_header_contains_proper_throttle_field(
((0, None), MockView, (
(0, None),
(0, None), (0, None),
(0, None), (0, None),
(0, '1') (0, '1')
)) )
)
def test_minutes_fields(self): def test_minutes_fields(self):
""" """
Ensure for minute based throttles. Ensure for minute based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(
((0, None), MockView_MinuteThrottling, (
(0, None),
(0, None), (0, None),
(0, None), (0, None),
(0, '60') (0, '60')
)) )
)
def test_next_rate_remains_constant_if_followed(self): def test_next_rate_remains_constant_if_followed(self):
""" """
If a client follows the recommended next request rate, If a client follows the recommended next request rate,
the throttling rate should stay constant. the throttling rate should stay constant.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(
((0, None), MockView_MinuteThrottling, (
(0, None),
(20, None), (20, None),
(40, None), (40, None),
(60, None), (60, None),
(80, None) (80, None)
)) )
)
def test_non_time_throttle(self): def test_non_time_throttle(self):
""" """

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import TestCase from django.test import TestCase
from rest_framework.templatetags.rest_framework import urlize_quoted_links from rest_framework.templatetags.rest_framework import urlize_quoted_links
import sys
class URLizerTests(TestCase): class URLizerTests(TestCase):

11
tox.ini
View File

@ -1,13 +1,20 @@
[tox] [tox]
downloadcache = {toxworkdir}/cache/ downloadcache = {toxworkdir}/cache/
envlist = envlist =
flake8,
py3.4-django1.7,py3.3-django1.7,py3.2-django1.7,py2.7-django1.7, py3.4-django1.7,py3.3-django1.7,py3.2-django1.7,py2.7-django1.7,
py3.4-django1.6,py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6, py3.4-django1.6,py3.3-django1.6,py3.2-django1.6,py2.7-django1.6,py2.6-django1.6,
py3.4-django1.5,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5, py3.4-django1.5,py3.3-django1.5,py3.2-django1.5,py2.7-django1.5,py2.6-django1.5,
py2.7-django1.4,py2.6-django1.4, py2.7-django1.4,py2.6-django1.4
[testenv] [testenv]
commands = py.test -q commands = ./runtests.py --fast
[testenv:flake8]
basepython = python2.7
deps = pytest==2.5.2
flake8==2.2.2
commands = ./runtests.py --lintonly
[testenv:py3.4-django1.7] [testenv:py3.4-django1.7]
basepython = python3.4 basepython = python3.4