Travis will now check if django rest framework conforms to pep8 before running the tests. Added autopep8 to the development process. Applied it for the first time.

This commit is contained in:
Omer Katz 2013-01-25 19:46:37 +03:00
parent e92a01224d
commit 862cafa81a
58 changed files with 763 additions and 458 deletions

View File

@ -14,6 +14,8 @@ install:
- pip install django-filter==0.5.4 --use-mirrors - pip install django-filter==0.5.4 --use-mirrors
- pip install -r development.txt - pip install -r development.txt
- export PYTHONPATH=. - export PYTHONPATH=.
before-script:
- pep8 --exclude=migrations --ignore="E501,E255,E261,W191,E101" .
script: script:
- python rest_framework/runtests/runtests.py - python rest_framework/runtests/runtests.py
- python rest_framework/runtests/runcoverage.py - python rest_framework/runtests/runcoverage.py

View File

@ -80,6 +80,10 @@ To start hacking type.
pip install -r development.txt pip install -r development.txt
Before pushing run.
autopep8 -r --in-place ./rest_framework/
To run the tests. To run the tests.
./rest_framework/runtests/runtests.py ./rest_framework/runtests/runtests.py

View File

@ -1,3 +1,4 @@
autopep>=0.8.5
pep>=1.4.1 pep>=1.4.1
coverage>=3.6 coverage>=3.6
django-discover-runner>=0.2.2 django-discover-runner>=0.2.2

View File

@ -7,7 +7,7 @@ from django.db import models
try: try:
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
except ImportError: # django < 1.5 except ImportError: # django < 1.5
from django.contrib.auth.models import User from django.contrib.auth.models import User
else: else:
User = get_user_model() User = get_user_model()
@ -18,18 +18,18 @@ class Migration(SchemaMigration):
def forwards(self, orm): def forwards(self, orm):
# Adding model 'Token' # Adding model 'Token'
db.create_table('authtoken_token', ( db.create_table('authtoken_token', (
('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)), ('key', self.gf('django.db.models.fields.CharField')
(max_length=40, primary_key=True)),
('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])), ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])),
('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)), ('created', self.gf('django.db.models.fields.DateTimeField')
(auto_now_add=True, blank=True)),
)) ))
db.send_create_signal('authtoken', ['Token']) db.send_create_signal('authtoken', ['Token'])
def backwards(self, orm): def backwards(self, orm):
# Deleting model 'Token' # Deleting model 'Token'
db.delete_table('authtoken_token') db.delete_table('authtoken_token')
models = { models = {
'auth.group': { 'auth.group': {
'Meta': {'object_name': 'Group'}, 'Meta': {'object_name': 'Group'},

View File

@ -15,10 +15,13 @@ class AuthTokenSerializer(serializers.Serializer):
if user: if user:
if not user.is_active: if not user.is_active:
raise serializers.ValidationError('User account is disabled.') raise serializers.ValidationError(
'User account is disabled.')
attrs['user'] = user attrs['user'] = user
return attrs return attrs
else: else:
raise serializers.ValidationError('Unable to login with provided credentials.') raise serializers.ValidationError(
'Unable to login with provided credentials.')
else: else:
raise serializers.ValidationError('Must include "username" and "password"') raise serializers.ValidationError(
'Must include "username" and "password"')

View File

@ -10,7 +10,8 @@ from rest_framework.authtoken.serializers import AuthTokenSerializer
class ObtainAuthToken(APIView): class ObtainAuthToken(APIView):
throttle_classes = () throttle_classes = ()
permission_classes = () permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) parser_classes = (
parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
renderer_classes = (renderers.JSONRenderer,) renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer serializer_class = AuthTokenSerializer
model = Token model = Token
@ -18,7 +19,8 @@ class ObtainAuthToken(APIView):
def post(self, request): def post(self, request):
serializer = self.serializer_class(data=request.DATA) serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid(): if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user']) token, created = Token.objects.get_or_create(
user=serializer.object['user'])
return Response({'token': token.key}) return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

View File

@ -100,7 +100,8 @@ else:
# https://github.com/markotibold/django-rest-framework/tree/patch # https://github.com/markotibold/django-rest-framework/tree/patch
http_method_names = set(View.http_method_names) http_method_names = set(View.http_method_names)
http_method_names.add('patch') http_method_names.add('patch')
View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django View.http_method_names = list(
http_method_names) # PATCH method is not implemented by Django
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better. # PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4): if django.VERSION >= (1, 4):
@ -184,7 +185,8 @@ else:
def _sanitize_token(token): def _sanitize_token(token):
# Allow only alphanum, and ensure we return a 'str' for the sake of the post # Allow only alphanum, and ensure we return a 'str' for the sake of the post
# processing middleware. # processing middleware.
token = re.sub('[^a-zA-Z0-9]', '', str(token.decode('ascii', 'ignore'))) token = re.sub(
'[^a-zA-Z0-9]', '', str(token.decode('ascii', 'ignore')))
if token == "": if token == "":
# In case the cookie has been truncated to nothing at some point. # In case the cookie has been truncated to nothing at some point.
return _get_new_csrf_key() return _get_new_csrf_key()
@ -218,12 +220,14 @@ else:
return None return None
try: try:
csrf_token = _sanitize_token(request.COOKIES[settings.CSRF_COOKIE_NAME]) csrf_token = _sanitize_token(
request.COOKIES[settings.CSRF_COOKIE_NAME])
# Use same token next time # Use same token next time
request.META['CSRF_COOKIE'] = csrf_token request.META['CSRF_COOKIE'] = csrf_token
except KeyError: except KeyError:
csrf_token = None csrf_token = None
# Generate token and store it in the request, so it's available to the view. # Generate token and store it in the request, so it's available
# to the view.
request.META["CSRF_COOKIE"] = _get_new_csrf_key() request.META["CSRF_COOKIE"] = _get_new_csrf_key()
# Wait until request.META["CSRF_COOKIE"] has been manipulated before # Wait until request.META["CSRF_COOKIE"] has been manipulated before
@ -231,7 +235,8 @@ else:
if getattr(callback, 'csrf_exempt', False): if getattr(callback, 'csrf_exempt', False):
return None return None
# Assume that anything not defined as 'safe' by RC2616 needs protection. # Assume that anything not defined as 'safe' by RC2616 needs
# protection.
if request.method not in ('GET', 'HEAD', 'OPTIONS', 'TRACE'): if request.method not in ('GET', 'HEAD', 'OPTIONS', 'TRACE'):
if getattr(request, '_dont_enforce_csrf_checks', False): if getattr(request, '_dont_enforce_csrf_checks', False):
# Mechanism to turn off CSRF checks for test suite. It comes after # Mechanism to turn off CSRF checks for test suite. It comes after
@ -258,7 +263,9 @@ else:
# we can use strict Referer checking. # we can use strict Referer checking.
referer = request.META.get('HTTP_REFERER') referer = request.META.get('HTTP_REFERER')
if referer is None: if referer is None:
logger.warning('Forbidden (%s): %s' % (REASON_NO_REFERER, request.path), logger.warning(
'Forbidden (%s): %s' % (
REASON_NO_REFERER, request.path),
extra={ extra={
'status_code': 403, 'status_code': 403,
'request': request, 'request': request,
@ -270,7 +277,8 @@ else:
good_referer = 'https://%s/' % request.get_host() good_referer = 'https://%s/' % request.get_host()
if not same_origin(referer, good_referer): if not same_origin(referer, good_referer):
reason = REASON_BAD_REFERER % (referer, good_referer) reason = REASON_BAD_REFERER % (referer, good_referer)
logger.warning('Forbidden (%s): %s' % (reason, request.path), logger.warning(
'Forbidden (%s): %s' % (reason, request.path),
extra={ extra={
'status_code': 403, 'status_code': 403,
'request': request, 'request': request,
@ -282,7 +290,9 @@ else:
# No CSRF cookie. For POST requests, we insist on a CSRF cookie, # No CSRF cookie. For POST requests, we insist on a CSRF cookie,
# and in this way we can avoid all CSRF attacks, including login # and in this way we can avoid all CSRF attacks, including login
# CSRF. # CSRF.
logger.warning('Forbidden (%s): %s' % (REASON_NO_CSRF_COOKIE, request.path), logger.warning(
'Forbidden (%s): %s' % (
REASON_NO_CSRF_COOKIE, request.path),
extra={ extra={
'status_code': 403, 'status_code': 403,
'request': request, 'request': request,
@ -293,15 +303,19 @@ else:
# check non-cookie token for match # check non-cookie token for match
request_csrf_token = "" request_csrf_token = ""
if request.method == "POST": if request.method == "POST":
request_csrf_token = request.POST.get('csrfmiddlewaretoken', '') request_csrf_token = request.POST.get(
'csrfmiddlewaretoken', '')
if request_csrf_token == "": if request_csrf_token == "":
# Fall back to X-CSRFToken, to make things easier for AJAX, # Fall back to X-CSRFToken, to make things easier for AJAX,
# and possible for PUT/DELETE # and possible for PUT/DELETE
request_csrf_token = request.META.get('HTTP_X_CSRFTOKEN', '') request_csrf_token = request.META.get(
'HTTP_X_CSRFTOKEN', '')
if not constant_time_compare(request_csrf_token, csrf_token): if not constant_time_compare(request_csrf_token, csrf_token):
logger.warning('Forbidden (%s): %s' % (REASON_BAD_TOKEN, request.path), logger.warning(
'Forbidden (%s): %s' % (
REASON_BAD_TOKEN, request.path),
extra={ extra={
'status_code': 403, 'status_code': 403,
'request': request, 'request': request,

View File

@ -30,10 +30,12 @@ def api_view(http_method_names):
# api_view applied with eg. string instead of list of strings # api_view applied with eg. string instead of list of strings
assert isinstance(http_method_names, (list, tuple)), \ assert isinstance(http_method_names, (list, tuple)), \
'@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ '@api_view expected a list of strings, recieved %s' % type(
http_method_names).__name__
allowed_methods = set(http_method_names) | set(('options',)) allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] WrappedAPIView.http_method_names = [method.lower(
) for method in allowed_methods]
def handler(self, *args, **kwargs): def handler(self, *args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
@ -49,8 +51,9 @@ def api_view(http_method_names):
WrappedAPIView.parser_classes = getattr(func, 'parser_classes', WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
APIView.parser_classes) APIView.parser_classes)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes', WrappedAPIView.authentication_classes = getattr(
APIView.authentication_classes) func, 'authentication_classes',
APIView.authentication_classes)
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes', WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
APIView.throttle_classes) APIView.throttle_classes)

View File

@ -221,16 +221,18 @@ class ModelField(WritableField):
raise ValueError("ModelField requires 'model_field' kwarg") raise ValueError("ModelField requires 'model_field' kwarg")
self.min_length = kwargs.pop('min_length', self.min_length = kwargs.pop('min_length',
getattr(self.model_field, 'min_length', None)) getattr(self.model_field, 'min_length', None))
self.max_length = kwargs.pop('max_length', self.max_length = kwargs.pop('max_length',
getattr(self.model_field, 'max_length', None)) getattr(self.model_field, 'max_length', None))
super(ModelField, self).__init__(*args, **kwargs) super(ModelField, self).__init__(*args, **kwargs)
if self.min_length is not None: if self.min_length is not None:
self.validators.append(validators.MinLengthValidator(self.min_length)) self.validators.append(
validators.MinLengthValidator(self.min_length))
if self.max_length is not None: if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length)) self.validators.append(
validators.MaxLengthValidator(self.max_length))
def from_native(self, value): def from_native(self, value):
rel = getattr(self.model_field, "rel", None) rel = getattr(self.model_field, "rel", None)
@ -349,7 +351,8 @@ class ChoiceField(WritableField):
""" """
super(ChoiceField, self).validate(value) super(ChoiceField, self).validate(value)
if value and not self.valid_value(value): if value and not self.valid_value(value):
raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) raise ValidationError(
self.error_messages['invalid_choice'] % {'value': value})
def valid_value(self, value): def valid_value(self, value):
""" """
@ -385,7 +388,7 @@ class EmailField(CharField):
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
result = copy.copy(self) result = copy.copy(self)
memo[id(self)] = result memo[id(self)] = result
#result.widget = copy.deepcopy(self.widget, memo) # result.widget = copy.deepcopy(self.widget, memo)
result.validators = self.validators[:] result.validators = self.validators[:]
return result return result
@ -395,7 +398,8 @@ class RegexField(CharField):
form_field_class = forms.RegexField form_field_class = forms.RegexField
def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) super(RegexField, self).__init__(max_length, min_length, *
args, **kwargs)
self.regex = regex self.regex = regex
def _get_regex(self): def _get_regex(self):
@ -595,7 +599,8 @@ class FileField(WritableField):
if self.max_length is not None and len(file_name) > self.max_length: if self.max_length is not None and len(file_name) > self.max_length:
error_values = {'max': self.max_length, 'length': len(file_name)} error_values = {'max': self.max_length, 'length': len(file_name)}
raise ValidationError(self.error_messages['max_length'] % error_values) raise ValidationError(
self.error_messages['max_length'] % error_values)
if not file_name: if not file_name:
raise ValidationError(self.error_messages['invalid']) raise ValidationError(self.error_messages['invalid'])
if not self.allow_empty_file and not file_size: if not self.allow_empty_file and not file_size:

View File

@ -15,7 +15,8 @@ class CreateModelMixin(object):
Should be mixed in with any `BaseView`. Should be mixed in with any `BaseView`.
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES) serializer = self.get_serializer(
data=request.DATA, files=request.FILES)
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)

View File

@ -43,20 +43,22 @@ class DefaultContentNegotiation(BaseContentNegotiation):
# Check the acceptable media types against each renderer, # Check the acceptable media types against each renderer,
# attempting more specific media types first # attempting more specific media types first
# NB. The inner loop here isn't as bad as it first looks :) # NB. The inner loop here isn't as bad as it first looks :)
# Worst case is we're looping over len(accept_list) * len(self.renderers) # Worst case is we're looping over len(accept_list) *
# len(self.renderers)
for media_type_set in order_by_precedence(accepts): for media_type_set in order_by_precedence(accepts):
for renderer in renderers: for renderer in renderers:
for media_type in media_type_set: for media_type in media_type_set:
if media_type_matches(renderer.media_type, media_type): if media_type_matches(renderer.media_type, media_type):
# Return the most specific media type as accepted. # Return the most specific media type as accepted.
if (_MediaType(renderer.media_type).precedence > if (_MediaType(renderer.media_type).precedence >
_MediaType(media_type).precedence): _MediaType(media_type).precedence):
# Eg client requests '*/*' # Eg client requests '*/*'
# Accepted media type is 'application/json' # Accepted media type is 'application/json'
return renderer, renderer.media_type return renderer, renderer.media_type
else: else:
# Eg client requests 'application/json; indent=8' # Eg client requests 'application/json; indent=8'
# Accepted media type is 'application/json; indent=8' # Accepted media type is 'application/json;
# indent=8'
return renderer, media_type return renderer, media_type
raise exceptions.NotAcceptable(available_renderers=renderers) raise exceptions.NotAcceptable(available_renderers=renderers)

View File

@ -68,7 +68,8 @@ class BasePaginationSerializer(serializers.Serializer):
else: else:
context_kwarg = {} context_kwarg = {}
self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) self.fields[results_field] = object_serializer(
source='object_list', **context_kwarg)
def to_native(self, obj): def to_native(self, obj):
""" """

View File

@ -151,7 +151,8 @@ class XMLParser(BaseParser):
if len(children) == 0: if len(children) == 0:
return self._type_convert(element.text) return self._type_convert(element.text)
else: else:
# if the fist child tag is list-item means all children are list-item # if the fist child tag is list-item means all children are list-
# item
if children[0].tag == "list-item": if children[0].tag == "list-item":
data = [] data = []
for child in children: for child in children:

View File

@ -59,7 +59,7 @@ class IsAuthenticatedOrReadOnly(BasePermission):
def has_permission(self, request, view, obj=None): def has_permission(self, request, view, obj=None):
if (request.method in SAFE_METHODS or if (request.method in SAFE_METHODS or
request.user and request.user and
request.user.is_authenticated()): request.user.is_authenticated()):
return True return True
return False return False
@ -109,6 +109,6 @@ class DjangoModelPermissions(BasePermission):
if (request.user and if (request.user and
request.user.is_authenticated() and request.user.is_authenticated() and
request.user.has_perms(perms, obj)): request.user.has_perms(perms, obj)):
return True return True
return False return False

View File

@ -35,9 +35,11 @@ class RelatedField(WritableField):
super(RelatedField, self).initialize(parent, field_name) super(RelatedField, self).initialize(parent, field_name)
if self.queryset is None and not self.read_only: if self.queryset is None and not self.read_only:
try: try:
manager = getattr(self.parent.opts.model, self.source or field_name) manager = getattr(
self.parent.opts.model, self.source or field_name)
if hasattr(manager, 'related'): # Forward if hasattr(manager, 'related'): # Forward
self.queryset = manager.related.model._default_manager.all() self.queryset = manager.related.model._default_manager.all(
)
else: # Reverse else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all() self.queryset = manager.field.rel.to._default_manager.all()
except: except:
@ -194,13 +196,15 @@ class PrimaryKeyRelatedField(RelatedField):
return desc return desc
return "%s - %s" % (desc, ident) return "%s - %s" % (desc, ident)
# TODO: Possibly change this to just take `obj`, through prob less performant # TODO: Possibly change this to just take `obj`, through prob less
# performant
def to_native(self, pk): def to_native(self, pk):
return pk return pk
def from_native(self, data): def from_native(self, data):
if self.queryset is None: if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument') raise Exception(
'Writable related fields must include a `queryset` argument')
try: try:
return self.queryset.get(pk=data) return self.queryset.get(pk=data)
@ -268,7 +272,8 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
def from_native(self, data): def from_native(self, data):
if self.queryset is None: if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument') raise Exception(
'Writable related fields must include a `queryset` argument')
try: try:
return self.queryset.get(pk=data) return self.queryset.get(pk=data)
@ -302,7 +307,8 @@ class SlugRelatedField(RelatedField):
def from_native(self, data): def from_native(self, data):
if self.queryset is None: if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument') raise Exception(
'Writable related fields must include a `queryset` argument')
try: try:
return self.queryset.get(**{self.slug_field: data}) return self.queryset.get(**{self.slug_field: data})
@ -394,10 +400,12 @@ class HyperlinkedRelatedField(RelatedField):
# Convert URL -> model instance pk # Convert URL -> model instance pk
# TODO: Use values_list # TODO: Use values_list
if self.queryset is None: if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument') raise Exception(
'Writable related fields must include a `queryset` argument')
try: try:
http_prefix = value.startswith('http:') or value.startswith('https:') http_prefix = value.startswith(
'http:') or value.startswith('https:')
except AttributeError: except AttributeError:
msg = self.error_messages['incorrect_type'] msg = self.error_messages['incorrect_type']
raise ValidationError(msg % type(value).__name__) raise ValidationError(msg % type(value).__name__)

View File

@ -33,7 +33,8 @@ class BaseRenderer(object):
format = None format = None
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
raise NotImplemented('Renderer class requires .render() to be implemented') raise NotImplemented(
'Renderer class requires .render() to be implemented')
class JSONRenderer(BaseRenderer): class JSONRenderer(BaseRenderer):
@ -206,7 +207,8 @@ class TemplateHTMLRenderer(BaseRenderer):
return [self.template_name] return [self.template_name]
elif hasattr(view, 'get_template_names'): elif hasattr(view, 'get_template_names'):
return view.get_template_names() return view.get_template_names()
raise ConfigurationError('Returned a template response with no template_name') raise ConfigurationError(
'Returned a template response with no template_name')
def get_exception_template(self, response): def get_exception_template(self, response):
template_names = [name % {'status_code': response.status_code} template_names = [name % {'status_code': response.status_code}
@ -314,7 +316,7 @@ class BrowsableAPIRenderer(BaseRenderer):
kwargs = {} kwargs = {}
kwargs['required'] = v.required kwargs['required'] = v.required
#if getattr(v, 'queryset', None): # if getattr(v, 'queryset', None):
# kwargs['queryset'] = v.queryset # kwargs['queryset'] = v.queryset
if getattr(v, 'choices', None) is not None: if getattr(v, 'choices', None) is not None:
@ -356,7 +358,8 @@ class BrowsableAPIRenderer(BaseRenderer):
fields = self.serializer_to_form_fields(serializer) fields = self.serializer_to_form_fields(serializer)
# Creating an on the fly form see: # Creating an on the fly form see:
# http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python # http://stackoverflow.com/questions/3915024/dynamically-creating-
# classes-python
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields) OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
data = (obj is not None) and serializer.data or None data = (obj is not None) and serializer.data or None
form_instance = OnTheFlyForm(data) form_instance = OnTheFlyForm(data)
@ -370,7 +373,8 @@ class BrowsableAPIRenderer(BaseRenderer):
""" """
# If we're not using content overloading there's no point in supplying a generic form, # If we're not using content overloading there's no point in supplying a generic form,
# as the view won't treat the form's value as the content of the request. # as the view won't treat the form's value as the content of the
# request.
if not (api_settings.FORM_CONTENT_OVERRIDE if not (api_settings.FORM_CONTENT_OVERRIDE
and api_settings.FORM_CONTENTTYPE_OVERRIDE): and api_settings.FORM_CONTENTTYPE_OVERRIDE):
return None return None
@ -424,7 +428,8 @@ class BrowsableAPIRenderer(BaseRenderer):
response = renderer_context['response'] response = renderer_context['response']
renderer = self.get_default_renderer(view) renderer = self.get_default_renderer(view)
content = self.get_content(renderer, data, accepted_media_type, renderer_context) content = self.get_content(
renderer, data, accepted_media_type, renderer_context)
put_form = self.get_form(view, 'PUT', request) put_form = self.get_form(view, 'PUT', request)
post_form = self.get_form(view, 'POST', request) post_form = self.get_form(view, 'POST', request)

View File

@ -172,12 +172,12 @@ class Request(object):
@user.setter @user.setter
def user(self, value): def user(self, value):
""" """
Sets the user on the current request. This is necessary to maintain Sets the user on the current request. This is necessary to maintain
compatilbility with django.contrib.auth where the user proprety is compatilbility with django.contrib.auth where the user proprety is
set in the login and logout functions. set in the login and logout functions.
""" """
self._user = value self._user = value
@property @property
def auth(self): def auth(self):
@ -233,7 +233,7 @@ class Request(object):
""" """
try: try:
content_length = int(self.META.get('CONTENT_LENGTH', content_length = int(self.META.get('CONTENT_LENGTH',
self.META.get('HTTP_CONTENT_LENGTH'))) self.META.get('HTTP_CONTENT_LENGTH')))
except (ValueError, TypeError): except (ValueError, TypeError):
content_length = 0 content_length = 0
@ -259,23 +259,24 @@ class Request(object):
# We only need to use form overloading on form POST requests. # We only need to use form overloading on form POST requests.
if (not USE_FORM_OVERLOADING if (not USE_FORM_OVERLOADING
or self._request.method != 'POST' or self._request.method != 'POST'
or not is_form_media_type(self._content_type)): or not is_form_media_type(self._content_type)):
return return
# At this point we're committed to parsing the request as form data. # At this point we're committed to parsing the request as form data.
self._data = self._request.POST self._data = self._request.POST
self._files = self._request.FILES self._files = self._request.FILES
# Method overloading - change the method and remove the param from the content. # Method overloading - change the method and remove the param from the
# content.
if (self._METHOD_PARAM and if (self._METHOD_PARAM and
self._METHOD_PARAM in self._data): self._METHOD_PARAM in self._data):
self._method = self._data[self._METHOD_PARAM].upper() self._method = self._data[self._METHOD_PARAM].upper()
# Content overloading - modify the content type, and force re-parse. # Content overloading - modify the content type, and force re-parse.
if (self._CONTENT_PARAM and if (self._CONTENT_PARAM and
self._CONTENTTYPE_PARAM and self._CONTENTTYPE_PARAM and
self._CONTENT_PARAM in self._data and self._CONTENT_PARAM in self._data and
self._CONTENTTYPE_PARAM in self._data): self._CONTENTTYPE_PARAM in self._data):
self._content_type = self._data[self._CONTENTTYPE_PARAM] self._content_type = self._data[self._CONTENTTYPE_PARAM]
self._stream = StringIO(self._data[self._CONTENT_PARAM]) self._stream = StringIO(self._data[self._CONTENT_PARAM])
self._data, self._files = (Empty, Empty) self._data, self._files = (Empty, Empty)

View File

@ -22,9 +22,9 @@ class Response(SimpleTemplateResponse):
self.data = data self.data = data
self.template_name = template_name self.template_name = template_name
self.exception = exception self.exception = exception
if headers: if headers:
for name,value in headers.iteritems(): for name, value in headers.iteritems():
self[name] = value self[name] = value
@property @property

View File

@ -21,6 +21,7 @@ except ImportError:
print("Coverage is not installed. Aborting...") print("Coverage is not installed. Aborting...")
exit(1) exit(1)
def report(cov, cov_files): def report(cov, cov_files):
pc = cov.report(cov_files) pc = cov.report(cov_files)
@ -32,6 +33,7 @@ def report(cov, cov_files):
return pc return pc
def prepare_report(project_dir): def prepare_report(project_dir):
cov_files = [] cov_files = []
@ -52,7 +54,8 @@ def prepare_report(project_dir):
if 'rest_framework.py' in files: if 'rest_framework.py' in files:
files.remove('rest_framework.py') files.remove('rest_framework.py')
cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')]) cov_files.extend([os.path.join(
path, file) for file in files if file.endswith('.py')])
return cov_files return cov_files
@ -96,7 +99,7 @@ def main():
report(cov, cov_files) report(cov, cov_files)
pc = report(cov, cov_files) pc = report(cov, cov_files)
if failures <> 0: if failures != 0:
sys.exit(failures) sys.exit(failures)
if pc < settings.CODE_COVERAGE_THRESHOLD: if pc < settings.CODE_COVERAGE_THRESHOLD:

View File

@ -3,7 +3,8 @@ import os
import sys import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "rest_framework.tests.settings") os.environ.setdefault(
"DJANGO_SETTINGS_MODULE", "rest_framework.tests.settings")
from django.core.management import execute_from_command_line from django.core.management import execute_from_command_line

View File

@ -28,7 +28,8 @@ class DictWithMetadata(dict):
Overriden to remove metadata from the dict, since it shouldn't be pickled Overriden to remove metadata from the dict, since it shouldn't be pickled
and may in some instances be unpickleable. and may in some instances be unpickleable.
""" """
# return an instance of the first dict in MRO that isn't a DictWithMetadata # return an instance of the first dict in MRO that isn't a
# DictWithMetadata
for base in self.__class__.__mro__: for base in self.__class__.__mro__:
if not isinstance(base, DictWithMetadata) and isinstance(base, dict): if not isinstance(base, DictWithMetadata) and isinstance(base, dict):
return base(self) return base(self)
@ -230,12 +231,14 @@ class BaseSerializer(Field):
if field_name in self._errors: if field_name in self._errors:
continue continue
try: try:
validate_method = getattr(self, 'validate_%s' % field_name, None) validate_method = getattr(
self, 'validate_%s' % field_name, None)
if validate_method: if validate_method:
source = field.source or field_name source = field.source or field_name
attrs = validate_method(attrs, source) attrs = validate_method(attrs, source)
except ValidationError as err: except ValidationError as err:
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) self._errors[field_name] = self._errors.get(
field_name, []) + list(err.messages)
# If there are already errors, we don't run .validate() because # If there are already errors, we don't run .validate() because
# field-validation failed and thus `attrs` may not be complete. # field-validation failed and thus `attrs` may not be complete.
@ -246,7 +249,8 @@ class BaseSerializer(Field):
except ValidationError as err: except ValidationError as err:
if hasattr(err, 'message_dict'): if hasattr(err, 'message_dict'):
for field_name, error_messages in err.message_dict.items(): for field_name, error_messages in err.message_dict.items():
self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) self._errors[field_name] = self._errors.get(
field_name, []) + list(error_messages)
elif hasattr(err, 'messages'): elif hasattr(err, 'messages'):
self._errors['non_field_errors'] = err.messages self._errors['non_field_errors'] = err.messages

View File

@ -44,11 +44,11 @@ DEFAULTS = {
), ),
'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation', 'rest_framework.negotiation.DefaultContentNegotiation',
'DEFAULT_MODEL_SERIALIZER_CLASS': 'DEFAULT_MODEL_SERIALIZER_CLASS':
'rest_framework.serializers.ModelSerializer', 'rest_framework.serializers.ModelSerializer',
'DEFAULT_PAGINATION_SERIALIZER_CLASS': 'DEFAULT_PAGINATION_SERIALIZER_CLASS':
'rest_framework.pagination.PaginationSerializer', 'rest_framework.pagination.PaginationSerializer',
'DEFAULT_THROTTLE_RATES': { 'DEFAULT_THROTTLE_RATES': {
'user': None, 'user': None,
@ -116,7 +116,8 @@ def import_from_string(val, setting_name):
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
except ImportError as e: except ImportError as e:
msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) msg = "Could not import '%s' for API setting '%s'. %s: %s." % (
val, setting_name, e.__class__.__name__, e)
raise ImportError(msg) raise ImportError(msg)

View File

@ -116,14 +116,17 @@ TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '&gt;', '"', "'"]
DOTS = ['&middot;', '*', '\xe2\x80\xa2', '&#149;', '&bull;', '&#8226;'] DOTS = ['&middot;', '*', '\xe2\x80\xa2', '&#149;', '&bull;', '&#8226;']
unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)')
word_split_re = re.compile(r'(\s+)') word_split_re = re.compile(r'(\s+)')
punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \ punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' %
('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), (
'|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) '|'.join([re.escape(
x) for x in LEADING_PUNCTUATION]),
'|'.join([re.escape(x) for x in TRAILING_PUNCTUATION])))
simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+') link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+')
html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE)
hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL)
trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z') trailing_empty_content_re = re.compile(
r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
# And the template tags themselves... # And the template tags themselves...
@ -211,7 +214,8 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
If autoescape is True, the link text and URLs will get autoescaped. If autoescape is True, the link text and URLs will get autoescaped.
""" """
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x trim_url = lambda x, limit=trim_url_limit: limit is not None and (
len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData) safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_unicode(text)) words = word_split_re.split(force_unicode(text))
nofollow_attr = nofollow and ' rel="nofollow"' or '' nofollow_attr = nofollow and ' rel="nofollow"' or ''
@ -225,9 +229,9 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
url = None url = None
if middle.startswith('http://') or middle.startswith('https://'): if middle.startswith('http://') or middle.startswith('https://'):
url = middle url = middle
elif middle.startswith('www.') or ('@' not in middle and \ elif middle.startswith('www.') or ('@' not in middle and
middle and middle[0] in string.ascii_letters + string.digits and \ middle and middle[0] in string.ascii_letters + string.digits and
(middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))):
url = 'http://%s' % middle url = 'http://%s' % middle
elif '@' in middle and not ':' in middle and simple_email_re.match(middle): elif '@' in middle and not ':' in middle and simple_email_re.match(middle):
url = 'mailto:%s' % middle url = 'mailto:%s' % middle
@ -238,7 +242,8 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
if autoescape and not safe_input: if autoescape and not safe_input:
lead, trail = escape(lead), escape(trail) lead, trail = escape(lead), escape(trail)
url, trimmed = escape(url), escape(trimmed) url, trimmed = escape(url), escape(trimmed)
middle = '<a href="%s"%s>%s</a>' % (url, nofollow_attr, trimmed) middle = '<a href="%s"%s>%s</a>' % (
url, nofollow_attr, trimmed)
words[i] = mark_safe('%s%s%s' % (lead, middle, trail)) words[i] = mark_safe('%s%s%s' % (lead, middle, trail))
else: else:
if safe_input: if safe_input:

View File

@ -55,7 +55,8 @@
# group.save() # group.save()
# self.assertEqual(0, User.objects.count()) # self.assertEqual(0, User.objects.count())
# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]}) # response = self.client.post('/users/', {'username': 'bar', 'password':
# 'baz', 'groups': [group.id]})
# self.assertEqual(response.status_code, 201) # self.assertEqual(response.status_code, 201)
# self.assertEqual(1, User.objects.count()) # self.assertEqual(1, User.objects.count())
@ -77,7 +78,8 @@
# group.save() # group.save()
# self.assertEqual(0, User.objects.count()) # self.assertEqual(0, User.objects.count())
# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]}) # response = self.client.post('/customusers/', {'username': 'bar',
# 'groups': [group.id]})
# self.assertEqual(response.status_code, 201) # self.assertEqual(response.status_code, 201)
# self.assertEqual(1, CustomUser.objects.count()) # self.assertEqual(1, CustomUser.objects.count())

View File

@ -5,7 +5,7 @@ TEMPLATE_DEBUG = DEBUG
DEBUG_PROPAGATE_EXCEPTIONS = True DEBUG_PROPAGATE_EXCEPTIONS = True
ADMINS = ( ADMINS = (
# ('Your Name', 'your_email@domain.com'), # ('Your Name', 'your_email@domain.com'),
) )
MANAGERS = ADMINS MANAGERS = ADMINS
@ -13,12 +13,13 @@ MANAGERS = ADMINS
DATABASES = { DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.sqlite3', 'ENGINE': 'django.db.backends.sqlite3',
# Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or
'NAME': 'sqlite.db', # Or path to database file if using sqlite3. # 'oracle'.
'USER': '', # Not used with sqlite3. 'NAME': 'sqlite.db', # Or path to database file if using sqlite3.
'PASSWORD': '', # Not used with sqlite3. 'USER': '', # Not used with sqlite3.
'HOST': '', # Set to empty string for localhost. Not used with sqlite3. 'PASSWORD': '', # Not used with sqlite3.
'PORT': '', # Set to empty string for default. Not used with sqlite3. 'HOST': '', # Set to empty string for localhost. Not used with sqlite3.
'PORT': '', # Set to empty string for default. Not used with sqlite3.
} }
} }
@ -68,7 +69,7 @@ TEMPLATE_LOADERS = (
'django.template.loaders.filesystem.Loader', 'django.template.loaders.filesystem.Loader',
'django.template.loaders.app_directories.Loader', 'django.template.loaders.app_directories.Loader',
# 'django.template.loaders.eggs.Loader', # 'django.template.loaders.eggs.Loader',
) )
MIDDLEWARE_CLASSES = ( MIDDLEWARE_CLASSES = (
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
@ -76,14 +77,14 @@ MIDDLEWARE_CLASSES = (
'django.middleware.csrf.CsrfViewMiddleware', 'django.middleware.csrf.CsrfViewMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware', 'django.contrib.messages.middleware.MessageMiddleware',
) )
ROOT_URLCONF = 'urls' ROOT_URLCONF = 'urls'
TEMPLATE_DIRS = ( TEMPLATE_DIRS = (
# Put strings here, like "/home/html/django_templates" or "C:/www/django/templates". # Put strings here, like "/home/html/django_templates" or "C:/www/django/templates".
# Always use forward slashes, even on Windows. # Always use forward slashes, even on Windows.
# Don't forget to use absolute paths, not relative paths. # Don't forget to use absolute paths, not relative paths.
) )
INSTALLED_APPS = ( INSTALLED_APPS = (
@ -100,7 +101,7 @@ INSTALLED_APPS = (
'rest_framework.authtoken', 'rest_framework.authtoken',
'rest_framework.tests', 'rest_framework.tests',
'discover_runner' 'discover_runner'
) )
TEST_RUNNER = 'discover_runner.runner.DiscoverRunner' TEST_RUNNER = 'discover_runner.runner.DiscoverRunner'
@ -114,7 +115,8 @@ if django.VERSION < (1, 3):
INSTALLED_APPS += ('staticfiles',) INSTALLED_APPS += ('staticfiles',)
# If we're running on the Jenkins server we want to archive the coverage reports as XML. # If we're running on the Jenkins server we want to archive the coverage
# reports as XML.
import os import os
if os.environ.get('HUDSON_URL', None): if os.environ.get('HUDSON_URL', None):

View File

@ -22,11 +22,15 @@ class MockView(APIView):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({'a': 1, 'b': 2, 'c': 3})
urlpatterns = patterns('', urlpatterns = patterns('',
(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), (r'^session/$', MockView.as_view(authentication_classes=[
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), SessionAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^basic/$', MockView.as_view(authentication_classes=[
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), BasicAuthentication])),
) (r'^token/$', MockView.as_view(authentication_classes=[
TokenAuthentication])),
(r'^auth-token/$',
'rest_framework.authtoken.views.obtain_auth_token'),
)
class BasicAuthTests(TestCase): class BasicAuthTests(TestCase):
@ -38,18 +42,23 @@ class BasicAuthTests(TestCase):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(
self.username, self.email, self.password)
def test_post_form_passing_basic_auth(self): def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() auth = 'Basic %s' % base64.encodestring(
response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) '%s:%s' % (self.username, self.password)).strip()
response = self.csrf_client.post(
'/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_passing_basic_auth(self): def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() auth = 'Basic %s' % base64.encodestring(
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', '%s:%s' % (self.username, self.password)).strip()
response = self.csrf_client.post(
'/basic/', json.dumps({'example': 'example'}), 'application/json',
HTTP_AUTHORIZATION=auth) HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -60,7 +69,8 @@ class BasicAuthTests(TestCase):
def test_post_json_failing_basic_auth(self): def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails""" """Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/basic/', json.dumps(
{'example': 'example'}), 'application/json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
@ -75,7 +85,8 @@ class SessionAuthTests(TestCase):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(
self.username, self.email, self.password)
def tearDown(self): def tearDown(self):
self.csrf_client.logout() self.csrf_client.logout()
@ -92,16 +103,20 @@ class SessionAuthTests(TestCase):
""" """
Ensure POSTing form over session authentication with logged in user and CSRF token passes. Ensure POSTing form over session authentication with logged in user and CSRF token passes.
""" """
self.non_csrf_client.login(username=self.username, password=self.password) self.non_csrf_client.login(
response = self.non_csrf_client.post('/session/', {'example': 'example'}) username=self.username, password=self.password)
response = self.non_csrf_client.post(
'/session/', {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_put_form_session_auth_passing(self): def test_put_form_session_auth_passing(self):
""" """
Ensure PUTting form over session authentication with logged in user and CSRF token passes. Ensure PUTting form over session authentication with logged in user and CSRF token passes.
""" """
self.non_csrf_client.login(username=self.username, password=self.password) self.non_csrf_client.login(
response = self.non_csrf_client.put('/session/', {'example': 'example'}) username=self.username, password=self.password)
response = self.non_csrf_client.put(
'/session/', {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_session_auth_failing(self): def test_post_form_session_auth_failing(self):
@ -121,7 +136,8 @@ class TokenAuthTests(TestCase):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(
self.username, self.email, self.password)
self.key = 'abcd1234' self.key = 'abcd1234'
self.token = Token.objects.create(key=self.key, user=self.user) self.token = Token.objects.create(key=self.key, user=self.user)
@ -129,13 +145,15 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self): def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(
'/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_passing_token_auth(self): def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', response = self.csrf_client.post(
'/token/', json.dumps({'example': 'example'}), 'application/json',
HTTP_AUTHORIZATION=auth) HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -146,7 +164,8 @@ class TokenAuthTests(TestCase):
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails""" """Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/token/', json.dumps(
{'example': 'example'}), 'application/json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):
@ -159,7 +178,7 @@ class TokenAuthTests(TestCase):
"""Ensure token login view using JSON POST works.""" """Ensure token login view using JSON POST works."""
client = Client(enforce_csrf_checks=True) client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': self.password}), 'application/json') json.dumps({'username': self.username, 'password': self.password}), 'application/json')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(json.loads(response.content)['token'], self.key) self.assertEqual(json.loads(response.content)['token'], self.key)
@ -167,20 +186,20 @@ class TokenAuthTests(TestCase):
"""Ensure token login view using JSON POST fails if bad credentials are used.""" """Ensure token login view using JSON POST fails if bad credentials are used."""
client = Client(enforce_csrf_checks=True) client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') json.dumps({'username': self.username, 'password': "badpass"}), 'application/json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
def test_token_login_json_missing_fields(self): def test_token_login_json_missing_fields(self):
"""Ensure token login view using JSON POST fails if missing fields.""" """Ensure token login view using JSON POST fails if missing fields."""
client = Client(enforce_csrf_checks=True) client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username}), 'application/json') json.dumps({'username': self.username}), 'application/json')
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
def test_token_login_form(self): def test_token_login_form(self):
"""Ensure token login view using form POST works.""" """Ensure token login view using form POST works."""
client = Client(enforce_csrf_checks=True) client = Client(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
{'username': self.username, 'password': self.password}) {'username': self.username, 'password': self.password})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(json.loads(response.content)['token'], self.key) self.assertEqual(json.loads(response.content)['token'], self.key)

View File

@ -24,12 +24,15 @@ class NestedResourceInstance(APIView):
pass pass
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^$', Root.as_view()), url(r'^$', Root.as_view()),
url(r'^resource/$', ResourceRoot.as_view()), url(r'^resource/$', ResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), url(r'^resource/(?P<key>[0-9]+)$',
url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()), ResourceInstance.as_view()),
url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()), url(r'^resource/(?P<key>[0-9]+)/$',
) NestedResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$',
NestedResourceInstance.as_view()),
)
class BreadcrumbTests(TestCase): class BreadcrumbTests(TestCase):
@ -49,22 +52,28 @@ class BreadcrumbTests(TestCase):
def test_resource_instance_breadcrumbs(self): def test_resource_instance_breadcrumbs(self):
url = '/resource/123' url = '/resource/123'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root',
'/resource/'),
('Resource Instance', '/resource/123')]) ('Resource Instance', '/resource/123')])
def test_nested_resource_breadcrumbs(self): def test_nested_resource_breadcrumbs(self):
url = '/resource/123/' url = '/resource/123/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root',
('Resource Instance', '/resource/123'), '/resource/'),
('Resource Instance',
'/resource/123'),
('Nested Resource Root', '/resource/123/')]) ('Nested Resource Root', '/resource/123/')])
def test_nested_resource_instance_breadcrumbs(self): def test_nested_resource_instance_breadcrumbs(self):
url = '/resource/123/abc' url = '/resource/123/abc'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root',
('Resource Instance', '/resource/123'), '/resource/'),
('Nested Resource Root', '/resource/123/'), ('Resource Instance',
'/resource/123'),
('Nested Resource Root',
'/resource/123/'),
('Nested Resource Instance', '/resource/123/abc')]) ('Nested Resource Instance', '/resource/123/abc')])
def test_broken_url_breadcrumbs_handled_gracefully(self): def test_broken_url_breadcrumbs_handled_gracefully(self):

View File

@ -14,7 +14,7 @@ from rest_framework.decorators import (
authentication_classes, authentication_classes,
throttle_classes, throttle_classes,
permission_classes, permission_classes,
) )
from rest_framework.tests.utils import RequestFactory from rest_framework.tests.utils import RequestFactory
@ -60,7 +60,8 @@ class DecoratorTestCase(TestCase):
request = self.factory.post('/') request = self.factory.post('/')
response = view(request) response = view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(
response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_calling_put_method(self): def test_calling_put_method(self):
@api_view(['GET', 'PUT']) @api_view(['GET', 'PUT'])
@ -73,7 +74,8 @@ class DecoratorTestCase(TestCase):
request = self.factory.post('/') request = self.factory.post('/')
response = view(request) response = view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(
response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_calling_patch_method(self): def test_calling_patch_method(self):
@api_view(['GET', 'PATCH']) @api_view(['GET', 'PATCH'])
@ -86,7 +88,8 @@ class DecoratorTestCase(TestCase):
request = self.factory.post('/') request = self.factory.post('/')
response = view(request) response = view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(
response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_renderer_classes(self): def test_renderer_classes(self):
@api_view(['GET']) @api_view(['GET'])
@ -104,7 +107,7 @@ class DecoratorTestCase(TestCase):
def view(request): def view(request):
self.assertEqual(len(request.parsers), 1) self.assertEqual(len(request.parsers), 1)
self.assertTrue(isinstance(request.parsers[0], self.assertTrue(isinstance(request.parsers[0],
JSONParser)) JSONParser))
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
@ -116,7 +119,7 @@ class DecoratorTestCase(TestCase):
def view(request): def view(request):
self.assertEqual(len(request.authenticators), 1) self.assertEqual(len(request.authenticators), 1)
self.assertTrue(isinstance(request.authenticators[0], self.assertTrue(isinstance(request.authenticators[0],
BasicAuthentication)) BasicAuthentication))
return Response({}) return Response({})
request = self.factory.get('/') request = self.factory.get('/')
@ -146,4 +149,5 @@ class DecoratorTestCase(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
response = view(request) response = view(request)
self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEquals(
response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)

View File

@ -30,7 +30,8 @@ class FileSerializerTests(TestCase):
file = StringIO.StringIO('stuff') file = StringIO.StringIO('stuff')
file.name = 'stuff.txt' file.name = 'stuff.txt'
file.size = file.len file.size = file.len
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) serializer = UploadedFileSerializer(
data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now) uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.object.created, uploaded_file.created) self.assertEquals(serializer.object.created, uploaded_file.created)

View File

@ -56,15 +56,17 @@ class IntegrationTestFiltering(TestCase):
""" """
base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
for i in range(10): for i in range(10):
text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. text = chr(i + ord(
base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
decimal = base_data[1] + i decimal = base_data[1] + i
date = base_data[2] - datetime.timedelta(days=i * 2) date = base_data[2] - datetime.timedelta(days=i * 2)
FilterableItem(text=text, decimal=decimal, date=date).save() FilterableItem(text=text, decimal=decimal, date=date).save()
self.objects = FilterableItem.objects self.objects = FilterableItem.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal,
for obj in self.objects.all() 'date': obj.date}
for obj in self.objects.all()
] ]
@unittest.skipUnless(django_filters, 'django-filters not installed') @unittest.skipUnless(django_filters, 'django-filters not installed')
@ -85,12 +87,14 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if f['decimal']
== search_decimal]
self.assertEquals(response.data, expected_data) self.assertEquals(response.data, expected_data)
# Tests that the date filter works. # Tests that the date filter works.
search_date = datetime.date(2012, 9, 22) search_date = datetime.date(2012, 9, 22)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' request = factory.get(
'/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] == search_date] expected_data = [f for f in self.data if f['date'] == search_date]
@ -110,7 +114,8 @@ class IntegrationTestFiltering(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data) self.assertEquals(response.data, self.data)
# Tests that the decimal filter set with 'lt' in the filter class works. # Tests that the decimal filter set with 'lt' in the filter class
# works.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render() response = view(request).render()
@ -120,28 +125,32 @@ class IntegrationTestFiltering(TestCase):
# Tests that the date filter set with 'gt' in the filter class works. # Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' request = factory.get(
'/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date] expected_data = [f for f in self.data if f['date'] > search_date]
self.assertEquals(response.data, expected_data) self.assertEquals(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works. # Tests that the text filter set with 'icontains' in the filter class
# works.
search_text = 'ff' search_text = 'ff'
request = factory.get('/?text=%s' % search_text) request = factory.get('/?text=%s' % search_text)
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()] expected_data = [f for f in self.data if search_text in f[
'text'].lower()]
self.assertEquals(response.data, expected_data) self.assertEquals(response.data, expected_data)
# Tests that multiple filters works. # Tests that multiple filters works.
search_decimal = Decimal('5.25') search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) request = factory.get(
'/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and expected_data = [f for f in self.data if f['date'] > search_date and
f['decimal'] < search_decimal] f['decimal'] < search_decimal]
self.assertEquals(response.data, expected_data) self.assertEquals(response.data, expected_data)
@unittest.skipUnless(django_filters, 'django-filters not installed') @unittest.skipUnless(django_filters, 'django-filters not installed')

View File

@ -49,8 +49,8 @@ class TestRootView(TestCase):
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text} {'id': obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = RootView.as_view() self.view = RootView.as_view()
@ -69,7 +69,7 @@ class TestRootView(TestCase):
""" """
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content), request = factory.post('/', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
@ -82,10 +82,12 @@ class TestRootView(TestCase):
""" """
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.put('/', json.dumps(content), request = factory.put('/', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEquals(
self.assertEquals(response.data, {"detail": "Method 'PUT' not allowed."}) response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEquals(
response.data, {"detail": "Method 'PUT' not allowed."})
def test_delete_root_view(self): def test_delete_root_view(self):
""" """
@ -93,8 +95,10 @@ class TestRootView(TestCase):
""" """
request = factory.delete('/') request = factory.delete('/')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEquals(
self.assertEquals(response.data, {"detail": "Method 'DELETE' not allowed."}) response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEquals(
response.data, {"detail": "Method 'DELETE' not allowed."})
def test_options_root_view(self): def test_options_root_view(self):
""" """
@ -124,7 +128,7 @@ class TestRootView(TestCase):
""" """
content = {'id': 999, 'text': 'foobar'} content = {'id': 999, 'text': 'foobar'}
request = factory.post('/', json.dumps(content), request = factory.post('/', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 4, 'text': u'foobar'}) self.assertEquals(response.data, {'id': 4, 'text': u'foobar'})
@ -142,8 +146,8 @@ class TestInstanceView(TestCase):
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text} {'id': obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = InstanceView.as_view() self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view() self.slug_based_view = SlugBasedInstanceView.as_view()
@ -163,10 +167,12 @@ class TestInstanceView(TestCase):
""" """
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.post('/', json.dumps(content), request = factory.post('/', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEquals(
self.assertEquals(response.data, {"detail": "Method 'POST' not allowed."}) response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEquals(
response.data, {"detail": "Method 'POST' not allowed."})
def test_put_instance_view(self): def test_put_instance_view(self):
""" """
@ -174,7 +180,7 @@ class TestInstanceView(TestCase):
""" """
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request, pk='1').render() response = self.view(request, pk='1').render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
@ -187,7 +193,7 @@ class TestInstanceView(TestCase):
""" """
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.patch('/1', json.dumps(content), request = factory.patch('/1', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
@ -234,7 +240,7 @@ class TestInstanceView(TestCase):
""" """
content = {'id': 999, 'text': 'foobar'} content = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
@ -249,7 +255,7 @@ class TestInstanceView(TestCase):
self.objects.get(id=1).delete() self.objects.get(id=1).delete()
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.put('/1', json.dumps(content), request = factory.put('/1', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'id': 1, 'text': 'foobar'}) self.assertEquals(response.data, {'id': 1, 'text': 'foobar'})
@ -262,9 +268,10 @@ class TestInstanceView(TestCase):
at the requested url if it doesn't exist. at the requested url if it doesn't exist.
""" """
content = {'text': 'foobar'} content = {'text': 'foobar'}
# pk fields can not be created on demand, only the database can set th pk for a new object # pk fields can not be created on demand, only the database can set th
# pk for a new object
request = factory.put('/5', json.dumps(content), request = factory.put('/5', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request, pk=5).render() response = self.view(request, pk=5).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5) new_obj = self.objects.get(pk=5)
@ -277,10 +284,11 @@ class TestInstanceView(TestCase):
""" """
content = {'text': 'foobar'} content = {'text': 'foobar'}
request = factory.put('/test_slug', json.dumps(content), request = factory.put('/test_slug', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.slug_based_view(request, slug='test_slug').render() response = self.slug_based_view(request, slug='test_slug').render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'}) self.assertEquals(
response.data, {'slug': 'test_slug', 'text': 'foobar'})
new_obj = SlugBasedModel.objects.get(slug='test_slug') new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEquals(new_obj.text, 'foobar') self.assertEquals(new_obj.text, 'foobar')
@ -311,7 +319,7 @@ class TestCreateModelWithAutoNowAddField(TestCase):
""" """
content = {'email': 'foobar@example.com', 'content': 'foobar'} content = {'email': 'foobar@example.com', 'content': 'foobar'}
request = factory.post('/', json.dumps(content), request = factory.post('/', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
created = self.objects.get(id=1) created = self.objects.get(id=1)

View File

@ -33,10 +33,10 @@ def not_found(request):
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^$', example), url(r'^$', example),
url(r'^permission_denied$', permission_denied), url(r'^permission_denied$', permission_denied),
url(r'^not_found$', not_found), url(r'^not_found$', not_found),
) )
class TemplateHTMLRendererTests(TestCase): class TemplateHTMLRendererTests(TestCase):

View File

@ -9,9 +9,11 @@ factory = RequestFactory()
class BlogPostCommentSerializer(serializers.ModelSerializer): class BlogPostCommentSerializer(serializers.ModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') url = serializers.HyperlinkedIdentityField(
view_name='blogpostcomment-detail')
text = serializers.CharField() text = serializers.CharField()
blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') blog_post_url = serializers.HyperlinkedRelatedField(
source='blog_post', view_name='blogpost-detail')
class Meta: class Meta:
model = BlogPostComment model = BlogPostComment
@ -20,7 +22,8 @@ class BlogPostCommentSerializer(serializers.ModelSerializer):
class PhotoSerializer(serializers.Serializer): class PhotoSerializer(serializers.Serializer):
description = serializers.CharField() description = serializers.CharField()
album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', album_url = serializers.HyperlinkedRelatedField(
source='album', view_name='album-detail',
queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title')
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
@ -81,18 +84,33 @@ class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/$',
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), BasicList.as_view(), name='basicmodel-list'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), name='basicmodel-detail'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(),
url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), name='anchor-detail'),
url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), url(r'^manytomany/$', ManyToManyList.as_view(),
url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), name='manytomanymodel-list'),
url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), url(
url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(
url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), ),
) name='manytomanymodel-detail'),
url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(),
name='blogpost-detail'),
url(r'^comments/$', BlogPostCommentListCreate.as_view(),
name='blogpostcomment-list'),
url(
r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(),
name='blogpostcomment-detail'),
url(
r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(),
name='album-detail'),
url(r'^photos/$',
PhotoListCreate.as_view(), name='photo-list'),
url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(
), name='optionalrelationmodel-detail'),
)
class TestBasicHyperlinkedView(TestCase): class TestBasicHyperlinkedView(TestCase):
@ -107,8 +125,8 @@ class TestBasicHyperlinkedView(TestCase):
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [
{'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text} {'url': 'http://testserver/basic/%d/' % obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.list_view = BasicList.as_view() self.list_view = BasicList.as_view()
self.detail_view = BasicDetail.as_view() self.detail_view = BasicDetail.as_view()
@ -151,12 +169,12 @@ class TestManyToManyHyperlinkedView(TestCase):
manytomany.rel.add(*anchors) manytomany.rel.add(*anchors)
self.data = [{ self.data = [{
'url': 'http://testserver/manytomany/1/', 'url': 'http://testserver/manytomany/1/',
'rel': [ 'rel': [
'http://testserver/anchor/1/', 'http://testserver/anchor/1/',
'http://testserver/anchor/2/', 'http://testserver/anchor/2/',
'http://testserver/anchor/3/', 'http://testserver/anchor/3/',
] ]
}] }]
self.list_view = ManyToManyList.as_view() self.list_view = ManyToManyList.as_view()
self.detail_view = ManyToManyDetail.as_view() self.detail_view = ManyToManyDetail.as_view()
@ -201,7 +219,8 @@ class TestCreateWithForeignKeys(TestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response['Location'], 'http://testserver/comments/1/') self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.blogpostcomment_set.count(), 1) self.assertEqual(self.post.blogpostcomment_set.count(), 1)
self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') self.assertEqual(
self.post.blogpostcomment_set.all()[0].text, 'A test comment')
class TestCreateWithForeignKeysAndCustomSlug(TestCase): class TestCreateWithForeignKeysAndCustomSlug(TestCase):
@ -224,9 +243,10 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
response = self.list_create_view(request) response = self.list_create_view(request)
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertNotIn('Location', response, self.assertNotIn('Location', response,
msg='Location should only be included if there is a "url" field on the serializer') msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.count(), 1)
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') self.assertEqual(
self.post.photo_set.all()[0].description, 'A test photo')
class TestOptionalRelationHyperlinkedView(TestCase): class TestOptionalRelationHyperlinkedView(TestCase):
@ -239,7 +259,8 @@ class TestOptionalRelationHyperlinkedView(TestCase):
OptionalRelationModel().save() OptionalRelationModel().save()
self.objects = OptionalRelationModel.objects self.objects = OptionalRelationModel.objects
self.detail_view = OptionalRelationDetail.as_view() self.detail_view = OptionalRelationDetail.as_view()
self.data = {"url": "http://testserver/optionalrelation/1/", "other": None} self.data = {"url":
"http://testserver/optionalrelation/1/", "other": None}
def test_get_detail_view(self): def test_get_detail_view(self):
""" """
@ -257,6 +278,6 @@ class TestOptionalRelationHyperlinkedView(TestCase):
should accept None for non existing relations. should accept None for non existing relations.
""" """
response = self.client.put('/optionalrelation/1/', response = self.client.put('/optionalrelation/1/',
data=json.dumps(self.data), data=json.dumps(self.data),
content_type='application/json') content_type='application/json')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)

View File

@ -62,8 +62,8 @@ class IntegrationTestPagination(TestCase):
BasicModel(text=char * 3).save() BasicModel(text=char * 3).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text} {'id': obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = RootView.as_view() self.view = RootView.as_view()
@ -103,15 +103,17 @@ class IntegrationTestPaginationAndFiltering(TestCase):
""" """
base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
for i in range(26): for i in range(26):
text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. text = chr(i + ord(
base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
decimal = base_data[1] + i decimal = base_data[1] + i
date = base_data[2] - datetime.timedelta(days=i * 2) date = base_data[2] - datetime.timedelta(days=i * 2)
FilterableItem(text=text, decimal=decimal, date=date).save() FilterableItem(text=text, decimal=decimal, date=date).save()
self.objects = FilterableItem.objects self.objects = FilterableItem.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal,
for obj in self.objects.all() 'date': obj.date}
for obj in self.objects.all()
] ]
self.view = FilterFieldsRootView.as_view() self.view = FilterFieldsRootView.as_view()
@ -180,7 +182,8 @@ class UnitTestPagination(TestCase):
""" """
Ensure context gets passed through to the object serializer. Ensure context gets passed through to the object serializer.
""" """
serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) serializer = PassOnContextPaginationSerializer(
self.first_page, context={'foo': 'bar'})
serializer.data serializer.data
results = serializer.fields[serializer.results_field] results = serializer.fields[serializer.results_field]
self.assertEquals(serializer.context, results.context) self.assertEquals(serializer.context, results.context)
@ -199,8 +202,8 @@ class TestUnpaginated(TestCase):
BasicModel(text=i).save() BasicModel(text=i).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text} {'id': obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = DefaultPageSizeKwargView.as_view() self.view = DefaultPageSizeKwargView.as_view()
@ -227,8 +230,8 @@ class TestCustomPaginateByParam(TestCase):
BasicModel(text=i).save() BasicModel(text=i).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text} {'id': obj.id, 'text': obj.text}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = PaginateByParamView.as_view() self.view = PaginateByParamView.as_view()
@ -254,7 +257,8 @@ class TestCustomPaginateByParam(TestCase):
class CustomField(serializers.Field): class CustomField(serializers.Field):
def to_native(self, value): def to_native(self, value):
if not 'view' in self.context: if not 'view' in self.context:
raise RuntimeError("context isn't getting passed into custom field") raise RuntimeError(
"context isn't getting passed into custom field")
return "value" return "value"
@ -277,4 +281,3 @@ class TestContextPassedToCustomField(TestCase):
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)

View File

@ -116,7 +116,7 @@
# def get_content_type(filename): # def get_content_type(filename):
# return mimetypes.guess_type(filename)[0] or 'application/octet-stream' # return mimetypes.guess_type(filename)[0] or 'application/octet-stream'
# #
#class TestMultiPartParser(TestCase): # class TestMultiPartParser(TestCase):
# def setUp(self): # def setUp(self):
# self.req = RequestFactory() # self.req = RequestFactory()
# self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')], # self.content_type, self.body = encode_multipart_formdata([('key1', 'val1'), ('key1', 'val2')],

View File

@ -18,16 +18,19 @@ class FieldTests(TestCase):
https://github.com/tomchristie/django-rest-framework/issues/446 https://github.com/tomchristie/django-rest-framework/issues/446
""" """
field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) field = serializers.PrimaryKeyRelatedField(
queryset=NullModel.objects.all())
self.assertRaises(serializers.ValidationError, field.from_native, '') self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) self.assertRaises(serializers.ValidationError, field.from_native, [])
def test_hyperlinked_related_field_with_empty_string(self): def test_hyperlinked_related_field_with_empty_string(self):
field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') field = serializers.HyperlinkedRelatedField(
queryset=NullModel.objects.all(), view_name='')
self.assertRaises(serializers.ValidationError, field.from_native, '') self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) self.assertRaises(serializers.ValidationError, field.from_native, [])
def test_slug_related_field_with_empty_string(self): def test_slug_related_field_with_empty_string(self):
field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') field = serializers.SlugRelatedField(
queryset=NullModel.objects.all(), slug_field='pk')
self.assertRaises(serializers.ValidationError, field.from_native, '') self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) self.assertRaises(serializers.ValidationError, field.from_native, [])

View File

@ -8,18 +8,28 @@ def dummy_view(request, pk):
pass pass
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), name='manytomanysource-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), name='manytomanytarget-detail'),
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), name='foreignkeysource-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view,
) name='foreignkeytarget-detail'),
url(
r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view,
name='nullableforeignkeysource-detail'),
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view,
name='onetoonetarget-detail'),
url(
r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view,
name='nullableonetoonesource-detail'),
)
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') sources = serializers.ManyHyperlinkedRelatedField(
view_name='manytomanysource-detail')
class Meta: class Meta:
model = ManyToManyTarget model = ManyToManyTarget
@ -31,7 +41,8 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail') sources = serializers.ManyHyperlinkedRelatedField(
view_name='foreignkeysource-detail')
class Meta: class Meta:
model = ForeignKeyTarget model = ForeignKeyTarget
@ -50,7 +61,8 @@ class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer)
# OneToOne # OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') nullable_source = serializers.HyperlinkedRelatedField(
view_name='nullableonetoonesource-detail')
class Meta: class Meta:
model = OneToOneTarget model = OneToOneTarget
@ -74,7 +86,8 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/1/', 'name':
u'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': u'source-2', {'url': '/manytomanysource/2/', 'name': u'source-2',
'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': u'source-3', {'url': '/manytomanysource/3/', 'name': u'source-3',
@ -90,7 +103,8 @@ class HyperlinkedManyToManyTests(TestCase):
'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', {'url': '/manytomanytarget/2/', 'name': u'target-2',
'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} {'url': '/manytomanytarget/3/', 'name':
u'target-3', 'sources': ['/manytomanysource/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -117,7 +131,8 @@ class HyperlinkedManyToManyTests(TestCase):
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']} data = {'url': '/manytomanytarget/1/', 'name': u'target-1',
'sources': ['/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data) serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -128,10 +143,12 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}, {'url': '/manytomanytarget/1/', 'name':
u'target-1', 'sources': ['/manytomanysource/1/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', {'url': '/manytomanytarget/2/', 'name': u'target-2',
'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} {'url': '/manytomanytarget/3/', 'name':
u'target-3', 'sources': ['/manytomanysource/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -149,7 +166,8 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/1/', 'name':
u'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': u'source-2', {'url': '/manytomanysource/2/', 'name': u'source-2',
'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': u'source-3', {'url': '/manytomanysource/3/', 'name': u'source-3',
@ -176,7 +194,8 @@ class HyperlinkedManyToManyTests(TestCase):
'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', {'url': '/manytomanytarget/2/', 'name': u'target-2',
'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}, {'url': '/manytomanytarget/3/', 'name':
u'target-3', 'sources': ['/manytomanysource/3/']},
{'url': '/manytomanytarget/4/', 'name': u'target-4', {'url': '/manytomanytarget/4/', 'name': u'target-4',
'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
] ]
@ -199,9 +218,12 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/1/', 'name':
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} {'url': '/foreignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name':
u'source-3', 'target': '/foreignkeytarget/1/'}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -211,12 +233,14 @@ class HyperlinkedForeignKeyTests(TestCase):
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', {'url': '/foreignkeytarget/1/', 'name': u'target-1',
'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, {'url': '/foreignkeytarget/2/', 'name':
u'target-2', 'sources': []},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'} data = {'url': '/foreignkeysource/1/', 'name': u'source-1',
'target': '/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -227,14 +251,18 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}, {'url': '/foreignkeysource/1/', 'name':
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/2/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} {'url': '/foreignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name':
u'source-3', 'target': '/foreignkeytarget/1/'}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': 2} data = {'url': '/foreignkeysource/1/', 'name': u'source-1',
'target': 2}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
@ -253,7 +281,8 @@ class HyperlinkedForeignKeyTests(TestCase):
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', {'url': '/foreignkeytarget/1/', 'name': u'target-1',
'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, {'url': '/foreignkeytarget/2/', 'name':
u'target-2', 'sources': []},
] ]
self.assertEquals(new_serializer.data, expected) self.assertEquals(new_serializer.data, expected)
@ -264,14 +293,16 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, {'url': '/foreignkeytarget/1/', 'name':
u'target-1', 'sources': ['/foreignkeysource/2/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', {'url': '/foreignkeytarget/2/', 'name': u'target-2',
'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create(self): def test_foreign_key_create(self):
data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'} data = {'url': '/foreignkeysource/4/', 'name': u'source-4',
'target': '/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data) serializer = ForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -282,10 +313,14 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/1/', 'name':
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/2/', 'name':
{'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}, u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name':
u'source-3', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/4/', 'name':
u'source-4', 'target': '/foreignkeytarget/2/'},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -302,19 +337,23 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, {'url': '/foreignkeytarget/1/', 'name':
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, u'target-1', 'sources': ['/foreignkeysource/2/']},
{'url': '/foreignkeytarget/2/', 'name':
u'target-2', 'sources': []},
{'url': '/foreignkeytarget/3/', 'name': u'target-3', {'url': '/foreignkeytarget/3/', 'name': u'target-3',
'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None} data = {'url': '/foreignkeysource/1/', 'name': u'source-1',
'target': None}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) self.assertEquals(
serializer.errors, {'target': [u'Value may not be null']})
class HyperlinkedNullableForeignKeyTests(TestCase): class HyperlinkedNullableForeignKeyTests(TestCase):
@ -326,21 +365,26 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(
name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} data = {'url': '/nullableforeignkeysource/4/', 'name':
u'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -351,10 +395,14 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
{'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
{'url': '/nullableforeignkeysource/4/', 'name':
u'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -363,8 +411,10 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''} data = {'url': '/nullableforeignkeysource/4/', 'name':
expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} u'source-4', 'target': ''}
expected_data = {'url': '/nullableforeignkeysource/4/',
'name': u'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -375,15 +425,20 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
{'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
{'url': '/nullableforeignkeysource/4/', 'name':
u'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} data = {'url': '/nullableforeignkeysource/1/', 'name':
u'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -394,9 +449,12 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': None},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -405,8 +463,10 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''} data = {'url': '/nullableforeignkeysource/1/', 'name':
expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} u'source-1', 'target': ''}
expected_data = {'url': '/nullableforeignkeysource/1/',
'name': u'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -417,9 +477,12 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': None},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -460,7 +523,9 @@ class HyperlinkedNullableOneToOneTests(TestCase):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset) serializer = NullableOneToOneTargetSerializer(queryset)
expected = [ expected = [
{'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'}, {'url': '/onetoonetarget/1/', 'name': u'target-1',
{'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None}, 'nullable_source': '/nullableonetoonesource/1/'},
{'url': '/onetoonetarget/2/', 'name': u'target-2',
'nullable_source': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -53,9 +53,12 @@ class ReverseForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 1, 'name': u'source-1', 'target': {'id': 1,
{'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, 'name': u'target-1'}},
{'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 2, 'name': u'source-2', 'target': {'id': 1,
'name': u'target-1'}},
{'id': 3, 'name': u'source-3', 'target': {'id': 1,
'name': u'target-1'}},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -81,15 +84,18 @@ class NestedNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(
name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 1, 'name': u'source-1', 'target': {'id': 1,
{'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, 'name': u'target-1'}},
{'id': 2, 'name': u'source-2', 'target': {'id': 1,
'name': u'target-1'}},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': u'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -108,7 +114,8 @@ class NestedNullableOneToOneTests(TestCase):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset) serializer = NullableOneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}}, {'id': 1, 'name': u'target-1', 'nullable_source': {
'id': 1, 'name': u'source-1', 'target': 1}},
{'id': 2, 'name': u'target-2', 'nullable_source': None}, {'id': 2, 'name': u'target-2', 'nullable_source': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -270,7 +270,8 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) self.assertEquals(
serializer.errors, {'target': [u'Value may not be null']})
class PKNullableForeignKeyTests(TestCase): class PKNullableForeignKeyTests(TestCase):
@ -280,7 +281,8 @@ class PKNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(
name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):

View File

@ -49,7 +49,8 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 1, 'name': u'target-1', 'sources': [
'source-1', 'source-2', 'source-3']},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': u'target-2', 'sources': []},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -77,10 +78,12 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Object with name=123 does not exist.']}) self.assertEquals(serializer.errors, {'target': [
u'Object with name=123 does not exist.']})
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']} data = {'id': 2, 'name': u'target-2', 'sources': [
'source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -89,7 +92,8 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset) new_serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 1, 'name': u'target-1', 'sources': [
'source-1', 'source-2', 'source-3']},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': u'target-2', 'sources': []},
] ]
self.assertEquals(new_serializer.data, expected) self.assertEquals(new_serializer.data, expected)
@ -102,7 +106,8 @@ class PKForeignKeyTests(TestCase):
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': ['source-2']}, {'id': 1, 'name': u'target-1', 'sources': ['source-2']},
{'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']}, {'id': 2, 'name': u'target-2', 'sources': [
'source-1', 'source-3']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -127,7 +132,8 @@ class PKForeignKeyTests(TestCase):
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']} data = {'id': 3, 'name': u'target-3', 'sources': [
'source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data) serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -140,7 +146,8 @@ class PKForeignKeyTests(TestCase):
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': ['source-2']}, {'id': 1, 'name': u'target-1', 'sources': ['source-2']},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': u'target-2', 'sources': []},
{'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']}, {'id': 3, 'name': u'target-3', 'sources': [
'source-1', 'source-3']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -149,7 +156,8 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) self.assertEquals(
serializer.errors, {'target': [u'Value may not be null']})
class SlugNullableForeignKeyTests(TestCase): class SlugNullableForeignKeyTests(TestCase):
@ -159,7 +167,8 @@ class SlugNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(
name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):

View File

@ -80,15 +80,20 @@ class HTMLView1(APIView):
return Response('text') return Response('text')
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), RendererA, RendererB])),
url(r'^cache$', MockGETView.as_view()), url(r'^$', MockView.as_view(
url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), renderer_classes=[RendererA, RendererB])),
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), url(r'^cache$', MockGETView.as_view()),
url(r'^html$', HTMLView.as_view()), url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[
url(r'^html1$', HTMLView1.as_view()), JSONRenderer, JSONPRenderer])),
url(r'^api', include('rest_framework.urls', namespace='rest_framework')) url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(
) renderer_classes=[JSONPRenderer])),
url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
url(r'^api', include(
'rest_framework.urls', namespace='rest_framework'))
)
class POSTDeniedPermission(permissions.BasePermission): class POSTDeniedPermission(permissions.BasePermission):
@ -168,7 +173,7 @@ class RendererEndToEndTests(TestCase):
param = '?%s=%s' % ( param = '?%s=%s' % (
api_settings.URL_ACCEPT_OVERRIDE, api_settings.URL_ACCEPT_OVERRIDE,
RendererB.media_type RendererB.media_type
) )
resp = self.client.get('/' + param) resp = self.client.get('/' + param)
self.assertEquals(resp['Content-Type'], RendererB.media_type) self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
@ -185,7 +190,7 @@ class RendererEndToEndTests(TestCase):
param = '?%s=%s' % ( param = '?%s=%s' % (
api_settings.URL_FORMAT_OVERRIDE, api_settings.URL_FORMAT_OVERRIDE,
RendererB.format RendererB.format
) )
resp = self.client.get('/' + param) resp = self.client.get('/' + param)
self.assertEquals(resp['Content-Type'], RendererB.media_type) self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
@ -205,9 +210,9 @@ class RendererEndToEndTests(TestCase):
param = '?%s=%s' % ( param = '?%s=%s' % (
api_settings.URL_FORMAT_OVERRIDE, api_settings.URL_FORMAT_OVERRIDE,
RendererB.format RendererB.format
) )
resp = self.client.get('/' + param, resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type) HTTP_ACCEPT=RendererB.media_type)
self.assertEquals(resp['Content-Type'], RendererB.media_type) self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS) self.assertEquals(resp.status_code, DUMMYSTATUS)
@ -262,7 +267,7 @@ class JSONPRendererTests(TestCase):
Test JSONP rendering with View JSON Renderer. Test JSONP rendering with View JSON Renderer.
""" """
resp = self.client.get('/jsonp/jsonrenderer', resp = self.client.get('/jsonp/jsonrenderer',
HTTP_ACCEPT='application/javascript') HTTP_ACCEPT='application/javascript')
self.assertEquals(resp.status_code, status.HTTP_200_OK) self.assertEquals(resp.status_code, status.HTTP_200_OK)
self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp['Content-Type'], 'application/javascript')
self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
@ -272,7 +277,7 @@ class JSONPRendererTests(TestCase):
Test JSONP rendering without View JSON Renderer. Test JSONP rendering without View JSON Renderer.
""" """
resp = self.client.get('/jsonp/nojsonrenderer', resp = self.client.get('/jsonp/nojsonrenderer',
HTTP_ACCEPT='application/javascript') HTTP_ACCEPT='application/javascript')
self.assertEquals(resp.status_code, status.HTTP_200_OK) self.assertEquals(resp.status_code, status.HTTP_200_OK)
self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp['Content-Type'], 'application/javascript')
self.assertEquals(resp.content, 'callback(%s);' % _flat_repr) self.assertEquals(resp.content, 'callback(%s);' % _flat_repr)
@ -282,11 +287,13 @@ class JSONPRendererTests(TestCase):
Test JSONP rendering with callback function name. Test JSONP rendering with callback function name.
""" """
callback_func = 'myjsonpcallback' callback_func = 'myjsonpcallback'
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, resp = self.client.get(
'/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript') HTTP_ACCEPT='application/javascript')
self.assertEquals(resp.status_code, status.HTTP_200_OK) self.assertEquals(resp.status_code, status.HTTP_200_OK)
self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp['Content-Type'], 'application/javascript')
self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr)) self.assertEquals(
resp.content, '%s(%s);' % (callback_func, _flat_repr))
if yaml: if yaml:
@ -380,7 +387,8 @@ class XMLRendererTestCase(TestCase):
Test XML rendering. Test XML rendering.
""" """
renderer = XMLRenderer() renderer = XMLRenderer()
content = renderer.render({'field': Decimal('111.2')}, 'application/xml') content = renderer.render(
{'field': Decimal('111.2')}, 'application/xml')
self.assertXMLContains(content, '<field>111.2</field>') self.assertXMLContains(content, '<field>111.2</field>')
def test_render_none(self): def test_render_none(self):
@ -405,15 +413,18 @@ class XMLRendererTestCase(TestCase):
Test XML rendering. Test XML rendering.
""" """
renderer = XMLRenderer() renderer = XMLRenderer()
content = StringIO(renderer.render(self._complex_data, 'application/xml')) content = StringIO(
renderer.render(self._complex_data, 'application/xml'))
parser = XMLParser() parser = XMLParser()
complex_data_out = parser.parse(content) complex_data_out = parser.parse(content)
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (
repr(self._complex_data), repr(complex_data_out))
self.assertEqual(self._complex_data, complex_data_out, error_msg) self.assertEqual(self._complex_data, complex_data_out, error_msg)
def assertXMLContains(self, xml, string): def assertXMLContains(self, xml, string):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) self.assertTrue(
xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>')) self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml)) self.assertTrue(string in xml, '%r not in %r' % (string, xml))

View File

@ -15,7 +15,7 @@ from rest_framework.parsers import (
FormParser, FormParser,
MultiPartParser, MultiPartParser,
JSONParser JSONParser
) )
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -53,7 +53,8 @@ class TestMethodOverloading(TestCase):
POST requests can be overloaded to another method by setting a POST requests can be overloaded to another method by setting a
reserved form field reserved form field
""" """
request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'})) request = Request(
factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
self.assertEqual(request.method, 'DELETE') self.assertEqual(request.method, 'DELETE')
@ -88,7 +89,8 @@ class TestContentParsing(TestCase):
""" """
content = 'qwerty' content = 'qwerty'
content_type = 'text/plain' content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type)) request = Request(
factory.post('/', content, content_type=content_type))
request.parsers = (PlainTextParser(),) request.parsers = (PlainTextParser(),)
self.assertEqual(request.DATA, content) self.assertEqual(request.DATA, content)
@ -112,8 +114,9 @@ class TestContentParsing(TestCase):
if VERSION >= (1, 5): if VERSION >= (1, 5):
from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart
request = Request(factory.put('/', encode_multipart(BOUNDARY, data), request = Request(
content_type=MULTIPART_CONTENT)) factory.put('/', encode_multipart(BOUNDARY, data),
content_type=MULTIPART_CONTENT))
else: else:
request = Request(factory.put('/', data)) request = Request(factory.put('/', data))
@ -240,8 +243,8 @@ class MockView(APIView):
return Response(status=status.INTERNAL_SERVER_ERROR) return Response(status=status.INTERNAL_SERVER_ERROR)
urlpatterns = patterns('', urlpatterns = patterns('',
(r'^$', MockView.as_view()), (r'^$', MockView.as_view()),
) )
class TestContentParsingWithAuthentication(TestCase): class TestContentParsingWithAuthentication(TestCase):
@ -252,7 +255,8 @@ class TestContentParsingWithAuthentication(TestCase):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(
self.username, self.email, self.password)
def test_user_logged_in_authentication_has_POST_when_not_logged_in(self): def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
""" """
@ -274,10 +278,12 @@ class TestContentParsingWithAuthentication(TestCase):
# content = {'example': 'example'} # content = {'example': 'example'}
# response = self.client.post('/', content) # response = self.client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed") # self.assertEqual(status.OK, response.status_code, "POST data is
# malformed")
# response = self.csrf_client.post('/', content) # response = self.csrf_client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed") # self.assertEqual(status.OK, response.status_code, "POST data is
# malformed")
class TestUserSetter(TestCase): class TestUserSetter(TestCase):

View File

@ -7,7 +7,7 @@ from rest_framework.renderers import (
BaseRenderer, BaseRenderer,
JSONRenderer, JSONRenderer,
BrowsableAPIRenderer BrowsableAPIRenderer
) )
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -64,15 +64,19 @@ class HTMLView1(APIView):
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), RendererA, RendererB])),
url(r'^html$', HTMLView.as_view()), url(r'^$', MockView.as_view(
url(r'^html1$', HTMLView1.as_view()), renderer_classes=[RendererA, RendererB])),
url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) url(r'^html$', HTMLView.as_view()),
) url(r'^html1$', HTMLView1.as_view()),
url(r'^restframework', include('rest_framework.urls',
namespace='rest_framework'))
)
# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... # TODO: Clean tests bellow - remove duplicates with above, better unit
# testing, ...
class RendererIntegrationTests(TestCase): class RendererIntegrationTests(TestCase):
""" """
End-to-end testing of renderers using an ResponseMixin on a generic view. End-to-end testing of renderers using an ResponseMixin on a generic view.
@ -122,7 +126,7 @@ class RendererIntegrationTests(TestCase):
param = '?%s=%s' % ( param = '?%s=%s' % (
api_settings.URL_ACCEPT_OVERRIDE, api_settings.URL_ACCEPT_OVERRIDE,
RendererB.media_type RendererB.media_type
) )
resp = self.client.get('/' + param) resp = self.client.get('/' + param)
self.assertEquals(resp['Content-Type'], RendererB.media_type) self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
@ -148,7 +152,7 @@ class RendererIntegrationTests(TestCase):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format, resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type) HTTP_ACCEPT=RendererB.media_type)
self.assertEquals(resp['Content-Type'], RendererB.media_type) self.assertEquals(resp['Content-Type'], RendererB.media_type)
self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEquals(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEquals(resp.status_code, DUMMYSTATUS) self.assertEquals(resp.status_code, DUMMYSTATUS)

View File

@ -10,8 +10,8 @@ def null_view(request):
pass pass
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^view$', null_view, name='view'), url(r'^view$', null_view, name='view'),
) )
class ReverseTests(TestCase): class ReverseTests(TestCase):

View File

@ -2,9 +2,10 @@ import datetime
import pickle import pickle
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, from rest_framework.tests.models import (
BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
class SubComment(object): class SubComment(object):
@ -117,7 +118,8 @@ class BasicTests(TestCase):
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected) self.assertEquals(serializer.object, expected)
self.assertFalse(serializer.object is expected) self.assertFalse(serializer.object is expected)
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') self.assertEquals(
serializer.data['sub_comment'], 'And Merry Christmas!')
def test_update(self): def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data) serializer = CommentSerializer(self.comment, data=self.data)
@ -125,14 +127,16 @@ class BasicTests(TestCase):
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected) self.assertEquals(serializer.object, expected)
self.assertTrue(serializer.object is expected) self.assertTrue(serializer.object is expected)
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') self.assertEquals(
serializer.data['sub_comment'], 'And Merry Christmas!')
def test_partial_update(self): def test_partial_update(self):
msg = 'Merry New Year!' msg = 'Merry New Year!'
partial_data = {'content': msg} partial_data = {'content': msg}
serializer = CommentSerializer(self.comment, data=partial_data) serializer = CommentSerializer(self.comment, data=partial_data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
serializer = CommentSerializer(self.comment, data=partial_data, partial=True) serializer = CommentSerializer(
self.comment, data=partial_data, partial=True)
expected = self.comment expected = self.comment
self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected) self.assertEquals(serializer.object, expected)
@ -146,7 +150,7 @@ class BasicTests(TestCase):
""" """
serializer = PersonSerializer(self.person) serializer = PersonSerializer(self.person)
self.assertEquals(set(serializer.data.keys()), self.assertEquals(set(serializer.data.keys()),
set(['name', 'age', 'info'])) set(['name', 'age', 'info']))
def test_field_with_dictionary(self): def test_field_with_dictionary(self):
""" """
@ -160,7 +164,8 @@ class BasicTests(TestCase):
""" """
Attempting to update fields set as read_only should have no effect. Attempting to update fields set as read_only should have no effect.
""" """
serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) serializer = PersonSerializer(
self.person, data={'name': 'dwight', 'age': 99})
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
instance = serializer.save() instance = serializer.save()
self.assertEquals(serializer.errors, {}) self.assertEquals(serializer.errors, {})
@ -186,13 +191,13 @@ class ValidationTests(TestCase):
serializer = CommentSerializer(data=self.data) serializer = CommentSerializer(data=self.data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, self.assertEquals(serializer.errors,
{'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
def test_update(self): def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data) serializer = CommentSerializer(self.comment, data=self.data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, self.assertEquals(serializer.errors,
{'content': [u'Ensure this value has at most 1000 characters (it has 1001).']}) {'content': [u'Ensure this value has at most 1000 characters (it has 1001).']})
def test_update_missing_field(self): def test_update_missing_field(self):
data = { data = {
@ -201,14 +206,15 @@ class ValidationTests(TestCase):
} }
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) self.assertEquals(
serializer.errors, {'email': [u'This field is required.']})
def test_missing_bool_with_default(self): def test_missing_bool_with_default(self):
"""Make sure that a boolean value with a 'False' value is not """Make sure that a boolean value with a 'False' value is not
mistaken for not having a default.""" mistaken for not having a default."""
data = { data = {
'title': 'Some action item', 'title': 'Some action item',
#No 'done' value. # No 'done' value.
} }
serializer = ActionItemSerializer(self.actionitem, data=data) serializer = ActionItemSerializer(self.actionitem, data=data)
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
@ -221,23 +227,27 @@ class ValidationTests(TestCase):
data = ['i am', 'a', 'list'] data = ['i am', 'a', 'list']
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(
serializer.errors, {'non_field_errors': [u'Invalid data']})
data = 'and i am a string' data = 'and i am a string'
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(
serializer.errors, {'non_field_errors': [u'Invalid data']})
data = 42 data = 42
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(
serializer.errors, {'non_field_errors': [u'Invalid data']})
def test_cross_field_validation(self): def test_cross_field_validation(self):
class CommentSerializerWithCrossFieldValidator(CommentSerializer): class CommentSerializerWithCrossFieldValidator(CommentSerializer):
def validate(self, attrs): def validate(self, attrs):
if attrs["email"] not in attrs["content"]: if attrs["email"] not in attrs["content"]:
raise serializers.ValidationError("Email address not in content") raise serializers.ValidationError(
"Email address not in content")
return attrs return attrs
data = { data = {
@ -253,7 +263,8 @@ class ValidationTests(TestCase):
serializer = CommentSerializerWithCrossFieldValidator(data=data) serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) self.assertEquals(serializer.errors, {'non_field_errors': [
u'Email address not in content']})
def test_null_is_true_fields(self): def test_null_is_true_fields(self):
""" """
@ -308,7 +319,8 @@ class CustomValidationTests(TestCase):
serializer = self.CommentSerializerWithFieldValidator(data=data) serializer = self.CommentSerializerWithFieldValidator(data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) self.assertEquals(
serializer.errors, {'content': [u'Test not in value']})
def test_missing_data(self): def test_missing_data(self):
""" """
@ -318,9 +330,11 @@ class CustomValidationTests(TestCase):
'email': 'tom@example.com', 'email': 'tom@example.com',
'created': datetime.datetime(2012, 1, 1) 'created': datetime.datetime(2012, 1, 1)
} }
serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) serializer = self.CommentSerializerWithFieldValidator(
data=incomplete_data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': [u'This field is required.']}) self.assertEquals(
serializer.errors, {'content': [u'This field is required.']})
def test_wrong_data(self): def test_wrong_data(self):
""" """
@ -333,7 +347,8 @@ class CustomValidationTests(TestCase):
} }
serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'email': [u'Enter a valid e-mail address.']}) self.assertEquals(
serializer.errors, {'email': [u'Enter a valid e-mail address.']})
class PositiveIntegerAsChoiceTests(TestCase): class PositiveIntegerAsChoiceTests(TestCase):
@ -353,7 +368,8 @@ class ModelValidationTests(TestCase):
serializer.save() serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'}) second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid()) self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']}) self.assertEqual(second_serializer.errors, {'title': [
u'Album with this Title already exists.']})
def test_foreign_key_with_partial(self): def test_foreign_key_with_partial(self):
""" """
@ -369,19 +385,21 @@ class ModelValidationTests(TestCase):
class Meta: class Meta:
model = Photo model = Photo
photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) photo_serializer = PhotoSerializer(
data={'description': 'test', 'album': album.pk})
self.assertTrue(photo_serializer.is_valid()) self.assertTrue(photo_serializer.is_valid())
photo = photo_serializer.save() photo = photo_serializer.save()
# Updating only the album (foreign key) # Updating only the album (foreign key)
photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True) photo_serializer = PhotoSerializer(
instance=photo, data={'album': album.pk}, partial=True)
self.assertTrue(photo_serializer.is_valid()) self.assertTrue(photo_serializer.is_valid())
self.assertTrue(photo_serializer.save()) self.assertTrue(photo_serializer.save())
# Updating only the description # Updating only the description
photo_serializer = PhotoSerializer(instance=photo, photo_serializer = PhotoSerializer(instance=photo,
data={'description': 'new'}, data={'description': 'new'},
partial=True) partial=True)
self.assertTrue(photo_serializer.is_valid()) self.assertTrue(photo_serializer.is_valid())
self.assertTrue(photo_serializer.save()) self.assertTrue(photo_serializer.save())
@ -391,15 +409,18 @@ class RegexValidationTest(TestCase):
def test_create_failed(self): def test_create_failed(self):
serializer = BookSerializer(data={'isbn': '1234567890'}) serializer = BookSerializer(data={'isbn': '1234567890'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': [
u'isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': '12345678901234'}) serializer = BookSerializer(data={'isbn': '12345678901234'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': [
u'isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': [
u'isbn has to be exact 13 numbers']})
def test_create_success(self): def test_create_success(self):
serializer = BookSerializer(data={'isbn': '1234567890123'}) serializer = BookSerializer(data={'isbn': '1234567890123'})
@ -415,7 +436,8 @@ class MetadataTests(TestCase):
'created': serializers.DateTimeField 'created': serializers.DateTimeField
} }
for field_name, field in expected.items(): for field_name, field in expected.items():
self.assertTrue(isinstance(serializer.data.fields[field_name], field)) self.assertTrue(
isinstance(serializer.data.fields[field_name], field))
class ManyToManyTests(TestCase): class ManyToManyTests(TestCase):
@ -605,7 +627,8 @@ class DefaultValueTests(TestCase):
instance = serializer.save() instance = serializer.save()
data = {'extra': 'extra_value'} data = {'extra': 'extra_value'}
serializer = self.serializer_class(instance=instance, data=data, partial=True) serializer = self.serializer_class(
instance=instance, data=data, partial=True)
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
instance = serializer.save() instance = serializer.save()
@ -674,7 +697,8 @@ class ManyRelatedTests(TestCase):
class BlogPostSerializer(serializers.Serializer): class BlogPostSerializer(serializers.Serializer):
title = serializers.CharField() title = serializers.CharField()
first_comment = BlogPostCommentSerializer(source='get_first_comment') first_comment = BlogPostCommentSerializer(
source='get_first_comment')
serializer = BlogPostSerializer(post) serializer = BlogPostSerializer(post)
@ -715,11 +739,11 @@ class RelatedTraversalTest(TestCase):
expected = { expected = {
'title': u'Test blog post', 'title': u'Test blog post',
'comments': [{ 'comments': [{
'text': u'I love this blog post', 'text': u'I love this blog post',
'post_owner': { 'post_owner': {
"name": u"django", "name": u"django",
"age": None "age": None
} }
}] }]
} }
@ -814,7 +838,7 @@ class BlankFieldTests(TestCase):
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
#test for issue #460 # test for issue #460
class SerializerPickleTests(TestCase): class SerializerPickleTests(TestCase):
""" """
Test pickleability of the output of Serializers Test pickleability of the output of Serializers
@ -926,5 +950,7 @@ class NestedSerializerContextTests(TestCase):
album_collection = AlbumCollection() album_collection = AlbumCollection()
album_collection.albums = [album1, album2] album_collection.albums = [album1, album2]
# This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers # This will raise RuntimeError if context doesn't get passed correctly
AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data # to the nested Serializers
AlbumCollectionSerializer(
album_collection, context={'context_item': 'album context'}).data

View File

@ -23,7 +23,7 @@ class TestSettingsManager(object):
def set(self, **kwargs): def set(self, **kwargs):
for k, v in kwargs.iteritems(): for k, v in kwargs.iteritems():
self._original_settings.setdefault(k, getattr(settings, k, self._original_settings.setdefault(k, getattr(settings, k,
NO_SETTING)) NO_SETTING))
setattr(settings, k, v) setattr(settings, k, v)
if 'INSTALLED_APPS' in kwargs: if 'INSTALLED_APPS' in kwargs:
self.syncdb() self.syncdb()
@ -63,5 +63,6 @@ class SettingsTestCase(TestCase):
class TestModelsTestCase(SettingsTestCase): class TestModelsTestCase(SettingsTestCase):
def setUp(self, *args, **kwargs): def setUp(self, *args, **kwargs):
installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) installed_apps = tuple(
settings.INSTALLED_APPS) + ('rest_framework.tests',)
self.settings_manager.set(INSTALLED_APPS=installed_apps) self.settings_manager.set(INSTALLED_APPS=installed_apps)

View File

@ -113,32 +113,34 @@ class ThrottlingTests(TestCase):
Ensure for second based throttles. Ensure for second based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView, self.ensure_response_header_contains_proper_throttle_field(MockView,
((0, None), ((0, None),
(0, None), (0, None),
(0, None), (0, None),
(0, '1') (0, '1')
)) ))
def test_minutes_fields(self): def test_minutes_fields(self):
""" """
Ensure for minute based throttles. Ensure for minute based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(
MockView_MinuteThrottling,
((0, None), ((0, None),
(0, None), (0, None),
(0, None), (0, None),
(0, '60') (0, '60')
)) ))
def test_next_rate_remains_constant_if_followed(self): def test_next_rate_remains_constant_if_followed(self):
""" """
If a client follows the recommended next request rate, If a client follows the recommended next request rate,
the throttling rate should stay constant. the throttling rate should stay constant.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(
MockView_MinuteThrottling,
((0, None), ((0, None),
(20, None), (20, None),
(40, None), (40, None),
(60, None), (60, None),
(80, None) (80, None)
)) ))

View File

@ -32,7 +32,8 @@ class FormatSuffixTests(TestCase):
for test_path in test_paths: for test_path in test_paths:
request = factory.get(test_path.path) request = factory.get(test_path.path)
try: try:
callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) callback, callback_args, callback_kwargs = resolver.resolve(
request.path_info)
except: except:
self.fail("Failed to resolve URL: %s" % request.path_info) self.fail("Failed to resolve URL: %s" % request.path_info)
self.assertEquals(callback_args, test_path.args) self.assertEquals(callback_args, test_path.args)
@ -74,6 +75,7 @@ class FormatSuffixTests(TestCase):
test_paths = [ test_paths = [
URLTestPath('/test/path', (), {'foo': 'bar', }), URLTestPath('/test/path', (), {'foo': 'bar', }),
URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), URLTestPath(
'/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
] ]
self._resolve_urlpatterns(urlpatterns, test_paths) self._resolve_urlpatterns(urlpatterns, test_paths)

View File

@ -19,7 +19,8 @@
# view = MockView() # view = MockView()
# content = {'qwerty': 'uiop'} # content = {'qwerty': 'uiop'}
# self.assertEqual(FormResource(view).validate_request(content, None), content) # self.assertEqual(FormResource(view).validate_request(content, None),
# content)
# def test_disabled_form_validator_get_bound_form_returns_none(self): # def test_disabled_form_validator_get_bound_form_returns_none(self):
# """If the view's form attribute is None on then # """If the view's form attribute is None on then
@ -36,7 +37,8 @@
# def test_disabled_model_form_validator_returns_content_unchanged(self): # def test_disabled_model_form_validator_returns_content_unchanged(self):
# """If the view's form is None and does not have a Resource with a model set then # """If the view's form is None and does not have a Resource with a model set then
# ModelFormValidator(view).validate_request(content, None) should just return the content unmodified.""" # ModelFormValidator(view).validate_request(content, None) should just
# return the content unmodified."""
# class DisabledModelFormView(View): # class DisabledModelFormView(View):
# resource = ModelResource # resource = ModelResource
@ -120,14 +122,16 @@
# def validation_failure_raises_response_exception(self, validator): # def validation_failure_raises_response_exception(self, validator):
# """If form validation fails a ResourceException 400 (Bad Request) should be raised.""" # """If form validation fails a ResourceException 400 (Bad Request) should be raised."""
# content = {} # content = {}
# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) # self.assertRaises(ImmediateResponse, validator.validate_request,
# content, None)
# def validation_does_not_allow_extra_fields_by_default(self, validator): # def validation_does_not_allow_extra_fields_by_default(self, validator):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)""" # broken clients more easily (eg submitting content with a misnamed field)"""
# content = {'qwerty': 'uiop', 'extra': 'extra'} # content = {'qwerty': 'uiop', 'extra': 'extra'}
# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) # self.assertRaises(ImmediateResponse, validator.validate_request,
# content, None)
# def validation_allows_extra_fields_if_explicitly_set(self, validator): # def validation_allows_extra_fields_if_explicitly_set(self, validator):
# """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names.""" # """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names."""
@ -147,7 +151,8 @@
# def validation_does_not_require_extra_fields_if_explicitly_set(self, validator): # def validation_does_not_require_extra_fields_if_explicitly_set(self, validator):
# """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names.""" # """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names."""
# content = {'qwerty': 'uiop'} # content = {'qwerty': 'uiop'}
# self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content) # self.assertEqual(validator._validate(content, None,
# allowed_extra_fields=('extra',)), content)
# def validation_failed_due_to_no_content_returns_appropriate_message(self, validator): # def validation_failed_due_to_no_content_returns_appropriate_message(self, validator):
# """If validation fails due to no content, ensure the response contains a single non-field error""" # """If validation fails due to no content, ensure the response contains a single non-field error"""
@ -198,7 +203,7 @@
# def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self): # def test_form_validation_returns_content_unchanged_if_already_valid_and_clean(self):
# validator = self.MockFormResource(self.MockFormView()) # validator = self.MockFormResource(self.MockFormView())
# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) # self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
# def test_form_validation_failure_raises_response_exception(self): # def test_form_validation_failure_raises_response_exception(self):
# validator = self.MockFormResource(self.MockFormView()) # validator = self.MockFormResource(self.MockFormView())
@ -214,33 +219,33 @@
# def test_validation_allows_unknown_fields_if_explicitly_allowed(self): # def test_validation_allows_unknown_fields_if_explicitly_allowed(self):
# validator = self.MockFormResource(self.MockFormView()) # validator = self.MockFormResource(self.MockFormView())
# self.validation_allows_unknown_fields_if_explicitly_allowed(validator) # self.validation_allows_unknown_fields_if_explicitly_allowed(validator)
# def test_validation_does_not_require_extra_fields_if_explicitly_set(self): # def test_validation_does_not_require_extra_fields_if_explicitly_set(self):
# validator = self.MockFormResource(self.MockFormView()) # validator = self.MockFormResource(self.MockFormView())
# self.validation_does_not_require_extra_fields_if_explicitly_set(validator) # self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
# def test_validation_failed_due_to_no_content_returns_appropriate_message(self): # def test_validation_failed_due_to_no_content_returns_appropriate_message(self):
# validator = self.MockFormResource(self.MockFormView()) # validator = self.MockFormResource(self.MockFormView())
# self.validation_failed_due_to_no_content_returns_appropriate_message(validator) # self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
# def test_validation_failed_due_to_field_error_returns_appropriate_message(self): # def test_validation_failed_due_to_field_error_returns_appropriate_message(self):
# validator = self.MockFormResource(self.MockFormView()) # validator = self.MockFormResource(self.MockFormView())
# self.validation_failed_due_to_field_error_returns_appropriate_message(validator) # self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
# def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self): # def test_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
# validator = self.MockFormResource(self.MockFormView()) # validator = self.MockFormResource(self.MockFormView())
# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) # self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
# def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): # def test_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
# validator = self.MockFormResource(self.MockFormView()) # validator = self.MockFormResource(self.MockFormView())
# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) # self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
# # Same tests on ModelResource # # Same tests on ModelResource
# def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self): # def test_modelform_validation_returns_content_unchanged_if_already_valid_and_clean(self):
# validator = self.MockModelResource(self.MockModelFormView()) # validator = self.MockModelResource(self.MockModelFormView())
# self.validation_returns_content_unchanged_if_already_valid_and_clean(validator) # self.validation_returns_content_unchanged_if_already_valid_and_clean(validator)
# def test_modelform_validation_failure_raises_response_exception(self): # def test_modelform_validation_failure_raises_response_exception(self):
# validator = self.MockModelResource(self.MockModelFormView()) # validator = self.MockModelResource(self.MockModelFormView())
@ -256,23 +261,23 @@
# def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self): # def test_modelform_validation_does_not_require_extra_fields_if_explicitly_set(self):
# validator = self.MockModelResource(self.MockModelFormView()) # validator = self.MockModelResource(self.MockModelFormView())
# self.validation_does_not_require_extra_fields_if_explicitly_set(validator) # self.validation_does_not_require_extra_fields_if_explicitly_set(validator)
# def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self): # def test_modelform_validation_failed_due_to_no_content_returns_appropriate_message(self):
# validator = self.MockModelResource(self.MockModelFormView()) # validator = self.MockModelResource(self.MockModelFormView())
# self.validation_failed_due_to_no_content_returns_appropriate_message(validator) # self.validation_failed_due_to_no_content_returns_appropriate_message(validator)
# def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self): # def test_modelform_validation_failed_due_to_field_error_returns_appropriate_message(self):
# validator = self.MockModelResource(self.MockModelFormView()) # validator = self.MockModelResource(self.MockModelFormView())
# self.validation_failed_due_to_field_error_returns_appropriate_message(validator) # self.validation_failed_due_to_field_error_returns_appropriate_message(validator)
# def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self): # def test_modelform_validation_failed_due_to_invalid_field_returns_appropriate_message(self):
# validator = self.MockModelResource(self.MockModelFormView()) # validator = self.MockModelResource(self.MockModelFormView())
# self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator) # self.validation_failed_due_to_invalid_field_returns_appropriate_message(validator)
# def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self): # def test_modelform_validation_failed_due_to_multiple_errors_returns_appropriate_message(self):
# validator = self.MockModelResource(self.MockModelFormView()) # validator = self.MockModelResource(self.MockModelFormView())
# self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator) # self.validation_failed_due_to_multiple_errors_returns_appropriate_message(validator)
# class TestModelFormValidator(TestCase): # class TestModelFormValidator(TestCase):
@ -299,26 +304,30 @@
# def test_property_fields_are_allowed_on_model_forms(self): # def test_property_fields_are_allowed_on_model_forms(self):
# """Validation on ModelForms may include property fields that exist on the Model to be included in the input.""" # """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'} # content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
# self.assertEqual(self.validator.validate_request(content, None), content) # self.assertEqual(self.validator.validate_request(content, None),
# content)
# def test_property_fields_are_not_required_on_model_forms(self): # def test_property_fields_are_not_required_on_model_forms(self):
# """Validation on ModelForms does not require property fields that exist on the Model to be included in the input.""" # """Validation on ModelForms does not require property fields that exist on the Model to be included in the input."""
# content = {'qwerty': 'example', 'uiop': 'example'} # content = {'qwerty': 'example', 'uiop': 'example'}
# self.assertEqual(self.validator.validate_request(content, None), content) # self.assertEqual(self.validator.validate_request(content, None),
# content)
# def test_extra_fields_not_allowed_on_model_forms(self): # def test_extra_fields_not_allowed_on_model_forms(self):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)""" # broken clients more easily (eg submitting content with a misnamed field)"""
# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'} # content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # self.assertRaises(ImmediateResponse, self.validator.validate_request,
# content, None)
# def test_validate_requires_fields_on_model_forms(self): # def test_validate_requires_fields_on_model_forms(self):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)""" # broken clients more easily (eg submitting content with a misnamed field)"""
# content = {'read_only': 'read only'} # content = {'read_only': 'read only'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # self.assertRaises(ImmediateResponse, self.validator.validate_request,
# content, None)
# def test_validate_does_not_require_blankable_fields_on_model_forms(self): # def test_validate_does_not_require_blankable_fields_on_model_forms(self):
# """Test standard ModelForm validation behaviour - fields with blank=True are not required.""" # """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
@ -326,4 +335,5 @@
# self.validator.validate_request(content, None) # self.validator.validate_request(content, None)
# def test_model_form_validator_uses_model_forms(self): # def test_model_form_validator_uses_model_forms(self):
# self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm)) # self.assertTrue(isinstance(self.validator.get_bound_form(),
# forms.ModelForm))

View File

@ -4,4 +4,4 @@ Blank URLConf just to keep runtests.py happy.
from rest_framework.compat import patterns from rest_framework.compat import patterns
urlpatterns = patterns('', urlpatterns = patterns('',
) )

View File

@ -9,7 +9,7 @@ class RequestFactory(RequestFactory):
super(RequestFactory, self).__init__(**defaults) super(RequestFactory, self).__init__(**defaults)
def patch(self, path, data={}, content_type=MULTIPART_CONTENT, def patch(self, path, data={}, content_type=MULTIPART_CONTENT,
**extra): **extra):
"Construct a PATCH request." "Construct a PATCH request."
patch_data = self._encode_data(data, content_type) patch_data = self._encode_data(data, content_type)
@ -17,11 +17,11 @@ class RequestFactory(RequestFactory):
parsed = urlparse(path) parsed = urlparse(path)
r = { r = {
'CONTENT_LENGTH': len(patch_data), 'CONTENT_LENGTH': len(patch_data),
'CONTENT_TYPE': content_type, 'CONTENT_TYPE': content_type,
'PATH_INFO': self._get_path(parsed), 'PATH_INFO': self._get_path(parsed),
'QUERY_STRING': parsed[4], 'QUERY_STRING': parsed[4],
'REQUEST_METHOD': 'PATCH', 'REQUEST_METHOD': 'PATCH',
'wsgi.input': FakePayload(patch_data), 'wsgi.input': FakePayload(patch_data),
} }
r.update(extra) r.update(extra)
return self.request(**r) return self.request(**r)

View File

@ -16,7 +16,8 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
patterns = apply_suffix_patterns(urlpattern.url_patterns, patterns = apply_suffix_patterns(urlpattern.url_patterns,
suffix_pattern, suffix_pattern,
suffix_required) suffix_required)
ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) ret.append(
url(regex, include(patterns, namespace, app_name), kwargs))
else: else:
# Regular URL pattern # Regular URL pattern

View File

@ -18,6 +18,7 @@ from rest_framework.compat import patterns, url
template_name = {'template_name': 'rest_framework/login.html'} template_name = {'template_name': 'rest_framework/login.html'}
urlpatterns = patterns('django.contrib.auth.views', urlpatterns = patterns('django.contrib.auth.views',
url(r'^login/$', 'login', template_name, name='login'), url(r'^login/$', 'login', template_name, name='login'),
url(r'^logout/$', 'logout', template_name, name='logout'), url(r'^logout/$', 'logout',
) template_name, name='logout'),
)

View File

@ -19,19 +19,21 @@ class XML2Dict(object):
for (k, v) in node.attrib.items(): for (k, v) in node.attrib.items():
k, v = self._namespace_split(k, v) k, v = self._namespace_split(k, v)
node_tree[k] = v node_tree[k] = v
#Save childrens # Save childrens
for child in node.getchildren(): for child in node.getchildren():
tag, tree = self._namespace_split(child.tag, self._parse_node(child)) tag, tree = self._namespace_split(
if tag not in node_tree: # the first time, so store it in dict child.tag, self._parse_node(child))
if tag not in node_tree: # the first time, so store it in dict
node_tree[tag] = tree node_tree[tag] = tree
continue continue
old = node_tree[tag] old = node_tree[tag]
if not isinstance(old, list): if not isinstance(old, list):
node_tree.pop(tag) node_tree.pop(tag)
node_tree[tag] = [old] # multi times, so change old dict to a list node_tree[tag] = [old]
# multi times, so change old dict to a list
node_tree[tag].append(tree) # add the new one node_tree[tag].append(tree) # add the new one
return node_tree return node_tree
def _namespace_split(self, tag, value): def _namespace_split(self, tag, value):
""" """
@ -52,7 +54,8 @@ class XML2Dict(object):
def fromstring(self, s): def fromstring(self, s):
"""parse a string""" """parse a string"""
t = ET.fromstring(s) t = ET.fromstring(s)
unused_root_tag, root_tree = self._namespace_split(t.tag, self._parse_node(t)) unused_root_tag, root_tree = self._namespace_split(
t.tag, self._parse_node(t))
return root_tree return root_tree

View File

@ -14,12 +14,14 @@ def get_breadcrumbs(url):
except Exception: except Exception:
pass pass
else: else:
# Check if this is a REST framework view, and if so add it to the breadcrumbs # Check if this is a REST framework view, and if so add it to the
# breadcrumbs
if isinstance(getattr(view, 'cls_instance', None), APIView): if isinstance(getattr(view, 'cls_instance', None), APIView):
# Don't list the same view twice in a row. # Don't list the same view twice in a row.
# Probably an optional trailing slash. # Probably an optional trailing slash.
if not seen or seen[-1] != view: if not seen or seen[-1] != view:
breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) breadcrumbs_list.insert(
0, (view.cls_instance.get_name(), prefix + url))
seen.append(view) seen.append(view)
if url == '': if url == '':
@ -27,10 +29,12 @@ def get_breadcrumbs(url):
return breadcrumbs_list return breadcrumbs_list
elif url.endswith('/'): elif url.endswith('/'):
# Drop trailing slash off the end and continue to try to resolve more breadcrumbs # Drop trailing slash off the end and continue to try to resolve
# more breadcrumbs
return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen)
# Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs # Drop trailing non-slash off the end and continue to try to resolve
# more breadcrumbs
return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/') prefix = get_script_prefix().rstrip('/')

View File

@ -84,10 +84,10 @@ else:
return node return node
SafeDumper.add_representer(SortedDict, SafeDumper.add_representer(SortedDict,
yaml.representer.SafeRepresenter.represent_dict) yaml.representer.SafeRepresenter.represent_dict)
SafeDumper.add_representer(DictWithMetadata, SafeDumper.add_representer(DictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict) yaml.representer.SafeRepresenter.represent_dict)
SafeDumper.add_representer(SortedDictWithMetadata, SafeDumper.add_representer(SortedDictWithMetadata,
yaml.representer.SafeRepresenter.represent_dict) yaml.representer.SafeRepresenter.represent_dict)
SafeDumper.add_representer(types.GeneratorType, SafeDumper.add_representer(types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list) yaml.representer.SafeRepresenter.represent_list)

View File

@ -56,7 +56,7 @@ class _MediaType(object):
if key != 'q' and other.params.get(key, None) != self.params.get(key, None): if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
return False return False
if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type: if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
return False return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type: if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:

View File

@ -36,7 +36,8 @@ def _remove_leading_indent(content):
# unindent the content if needed # unindent the content if needed
if whitespace_counts: if whitespace_counts:
whitespace_pattern = '^' + (' ' * min(whitespace_counts)) whitespace_pattern = '^' + (' ' * min(whitespace_counts))
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) content = re.sub(
re.compile(whitespace_pattern, re.MULTILINE), '', content)
content = content.strip('\n') content = content.strip('\n')
return content return content
@ -299,7 +300,8 @@ class APIView(View):
self.permission_denied(request) self.permission_denied(request)
self.check_throttles(request) self.check_throttles(request)
# Perform content negotiation and store the accepted info on the request # Perform content negotiation and store the accepted info on the
# request
neg = self.perform_content_negotiation(request) neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg request.accepted_renderer, request.accepted_media_type = neg
@ -383,7 +385,8 @@ class APIView(View):
except Exception as exc: except Exception as exc:
response = self.handle_exception(exc) response = self.handle_exception(exc)
self.response = self.finalize_response(request, response, *args, **kwargs) self.response = self.finalize_response(
request, response, *args, **kwargs)
return self.response return self.response
def options(self, request, *args, **kwargs): def options(self, request, *args, **kwargs):