Travis will now check if django rest framework conforms to pep8 before running the tests. Added autopep8 to the development process. Applied it for the first time.

This commit is contained in:
Omer Katz 2013-01-25 19:46:37 +03:00
parent e92a01224d
commit 862cafa81a
58 changed files with 763 additions and 458 deletions

View File

@ -14,6 +14,8 @@ install:
- pip install django-filter==0.5.4 --use-mirrors - pip install django-filter==0.5.4 --use-mirrors
- pip install -r development.txt - pip install -r development.txt
- export PYTHONPATH=. - export PYTHONPATH=.
before-script:
- pep8 --exclude=migrations --ignore="E501,E255,E261,W191,E101" .
script: script:
- python rest_framework/runtests/runtests.py - python rest_framework/runtests/runtests.py
- python rest_framework/runtests/runcoverage.py - python rest_framework/runtests/runcoverage.py

View File

@ -80,6 +80,10 @@ To start hacking type.
pip install -r development.txt pip install -r development.txt
Before pushing run.
autopep8 -r --in-place ./rest_framework/
To run the tests. To run the tests.
./rest_framework/runtests/runtests.py ./rest_framework/runtests/runtests.py

View File

@ -1,3 +1,4 @@
autopep>=0.8.5
pep>=1.4.1 pep>=1.4.1
coverage>=3.6 coverage>=3.6
django-discover-runner>=0.2.2 django-discover-runner>=0.2.2

View File

@ -18,18 +18,18 @@ class Migration(SchemaMigration):
def forwards(self, orm): def forwards(self, orm):
# Adding model 'Token' # Adding model 'Token'
db.create_table('authtoken_token', ( db.create_table('authtoken_token', (
('key', self.gf('django.db.models.fields.CharField')(max_length=40, primary_key=True)), ('key', self.gf('django.db.models.fields.CharField')
(max_length=40, primary_key=True)),
('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])), ('user', self.gf('django.db.models.fields.related.OneToOneField')(related_name='auth_token', unique=True, to=orm['%s.%s' % (User._meta.app_label, User._meta.object_name)])),
('created', self.gf('django.db.models.fields.DateTimeField')(auto_now_add=True, blank=True)), ('created', self.gf('django.db.models.fields.DateTimeField')
(auto_now_add=True, blank=True)),
)) ))
db.send_create_signal('authtoken', ['Token']) db.send_create_signal('authtoken', ['Token'])
def backwards(self, orm): def backwards(self, orm):
# Deleting model 'Token' # Deleting model 'Token'
db.delete_table('authtoken_token') db.delete_table('authtoken_token')
models = { models = {
'auth.group': { 'auth.group': {
'Meta': {'object_name': 'Group'}, 'Meta': {'object_name': 'Group'},

View File

@ -15,10 +15,13 @@ class AuthTokenSerializer(serializers.Serializer):
if user: if user:
if not user.is_active: if not user.is_active:
raise serializers.ValidationError('User account is disabled.') raise serializers.ValidationError(
'User account is disabled.')
attrs['user'] = user attrs['user'] = user
return attrs return attrs
else: else:
raise serializers.ValidationError('Unable to login with provided credentials.') raise serializers.ValidationError(
'Unable to login with provided credentials.')
else: else:
raise serializers.ValidationError('Must include "username" and "password"') raise serializers.ValidationError(
'Must include "username" and "password"')

View File

@ -10,7 +10,8 @@ from rest_framework.authtoken.serializers import AuthTokenSerializer
class ObtainAuthToken(APIView): class ObtainAuthToken(APIView):
throttle_classes = () throttle_classes = ()
permission_classes = () permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) parser_classes = (
parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
renderer_classes = (renderers.JSONRenderer,) renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer serializer_class = AuthTokenSerializer
model = Token model = Token
@ -18,7 +19,8 @@ class ObtainAuthToken(APIView):
def post(self, request): def post(self, request):
serializer = self.serializer_class(data=request.DATA) serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid(): if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user']) token, created = Token.objects.get_or_create(
user=serializer.object['user'])
return Response({'token': token.key}) return Response({'token': token.key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

View File

@ -100,7 +100,8 @@ else:
# https://github.com/markotibold/django-rest-framework/tree/patch # https://github.com/markotibold/django-rest-framework/tree/patch
http_method_names = set(View.http_method_names) http_method_names = set(View.http_method_names)
http_method_names.add('patch') http_method_names.add('patch')
View.http_method_names = list(http_method_names) # PATCH method is not implemented by Django View.http_method_names = list(
http_method_names) # PATCH method is not implemented by Django
# PUT, DELETE do not require CSRF until 1.4. They should. Make it better. # PUT, DELETE do not require CSRF until 1.4. They should. Make it better.
if django.VERSION >= (1, 4): if django.VERSION >= (1, 4):
@ -184,7 +185,8 @@ else:
def _sanitize_token(token): def _sanitize_token(token):
# Allow only alphanum, and ensure we return a 'str' for the sake of the post # Allow only alphanum, and ensure we return a 'str' for the sake of the post
# processing middleware. # processing middleware.
token = re.sub('[^a-zA-Z0-9]', '', str(token.decode('ascii', 'ignore'))) token = re.sub(
'[^a-zA-Z0-9]', '', str(token.decode('ascii', 'ignore')))
if token == "": if token == "":
# In case the cookie has been truncated to nothing at some point. # In case the cookie has been truncated to nothing at some point.
return _get_new_csrf_key() return _get_new_csrf_key()
@ -218,12 +220,14 @@ else:
return None return None
try: try:
csrf_token = _sanitize_token(request.COOKIES[settings.CSRF_COOKIE_NAME]) csrf_token = _sanitize_token(
request.COOKIES[settings.CSRF_COOKIE_NAME])
# Use same token next time # Use same token next time
request.META['CSRF_COOKIE'] = csrf_token request.META['CSRF_COOKIE'] = csrf_token
except KeyError: except KeyError:
csrf_token = None csrf_token = None
# Generate token and store it in the request, so it's available to the view. # Generate token and store it in the request, so it's available
# to the view.
request.META["CSRF_COOKIE"] = _get_new_csrf_key() request.META["CSRF_COOKIE"] = _get_new_csrf_key()
# Wait until request.META["CSRF_COOKIE"] has been manipulated before # Wait until request.META["CSRF_COOKIE"] has been manipulated before
@ -231,7 +235,8 @@ else:
if getattr(callback, 'csrf_exempt', False): if getattr(callback, 'csrf_exempt', False):
return None return None
# Assume that anything not defined as 'safe' by RC2616 needs protection. # Assume that anything not defined as 'safe' by RC2616 needs
# protection.
if request.method not in ('GET', 'HEAD', 'OPTIONS', 'TRACE'): if request.method not in ('GET', 'HEAD', 'OPTIONS', 'TRACE'):
if getattr(request, '_dont_enforce_csrf_checks', False): if getattr(request, '_dont_enforce_csrf_checks', False):
# Mechanism to turn off CSRF checks for test suite. It comes after # Mechanism to turn off CSRF checks for test suite. It comes after
@ -258,7 +263,9 @@ else:
# we can use strict Referer checking. # we can use strict Referer checking.
referer = request.META.get('HTTP_REFERER') referer = request.META.get('HTTP_REFERER')
if referer is None: if referer is None:
logger.warning('Forbidden (%s): %s' % (REASON_NO_REFERER, request.path), logger.warning(
'Forbidden (%s): %s' % (
REASON_NO_REFERER, request.path),
extra={ extra={
'status_code': 403, 'status_code': 403,
'request': request, 'request': request,
@ -270,7 +277,8 @@ else:
good_referer = 'https://%s/' % request.get_host() good_referer = 'https://%s/' % request.get_host()
if not same_origin(referer, good_referer): if not same_origin(referer, good_referer):
reason = REASON_BAD_REFERER % (referer, good_referer) reason = REASON_BAD_REFERER % (referer, good_referer)
logger.warning('Forbidden (%s): %s' % (reason, request.path), logger.warning(
'Forbidden (%s): %s' % (reason, request.path),
extra={ extra={
'status_code': 403, 'status_code': 403,
'request': request, 'request': request,
@ -282,7 +290,9 @@ else:
# No CSRF cookie. For POST requests, we insist on a CSRF cookie, # No CSRF cookie. For POST requests, we insist on a CSRF cookie,
# and in this way we can avoid all CSRF attacks, including login # and in this way we can avoid all CSRF attacks, including login
# CSRF. # CSRF.
logger.warning('Forbidden (%s): %s' % (REASON_NO_CSRF_COOKIE, request.path), logger.warning(
'Forbidden (%s): %s' % (
REASON_NO_CSRF_COOKIE, request.path),
extra={ extra={
'status_code': 403, 'status_code': 403,
'request': request, 'request': request,
@ -293,15 +303,19 @@ else:
# check non-cookie token for match # check non-cookie token for match
request_csrf_token = "" request_csrf_token = ""
if request.method == "POST": if request.method == "POST":
request_csrf_token = request.POST.get('csrfmiddlewaretoken', '') request_csrf_token = request.POST.get(
'csrfmiddlewaretoken', '')
if request_csrf_token == "": if request_csrf_token == "":
# Fall back to X-CSRFToken, to make things easier for AJAX, # Fall back to X-CSRFToken, to make things easier for AJAX,
# and possible for PUT/DELETE # and possible for PUT/DELETE
request_csrf_token = request.META.get('HTTP_X_CSRFTOKEN', '') request_csrf_token = request.META.get(
'HTTP_X_CSRFTOKEN', '')
if not constant_time_compare(request_csrf_token, csrf_token): if not constant_time_compare(request_csrf_token, csrf_token):
logger.warning('Forbidden (%s): %s' % (REASON_BAD_TOKEN, request.path), logger.warning(
'Forbidden (%s): %s' % (
REASON_BAD_TOKEN, request.path),
extra={ extra={
'status_code': 403, 'status_code': 403,
'request': request, 'request': request,

View File

@ -30,10 +30,12 @@ def api_view(http_method_names):
# api_view applied with eg. string instead of list of strings # api_view applied with eg. string instead of list of strings
assert isinstance(http_method_names, (list, tuple)), \ assert isinstance(http_method_names, (list, tuple)), \
'@api_view expected a list of strings, recieved %s' % type(http_method_names).__name__ '@api_view expected a list of strings, recieved %s' % type(
http_method_names).__name__
allowed_methods = set(http_method_names) | set(('options',)) allowed_methods = set(http_method_names) | set(('options',))
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] WrappedAPIView.http_method_names = [method.lower(
) for method in allowed_methods]
def handler(self, *args, **kwargs): def handler(self, *args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
@ -49,7 +51,8 @@ def api_view(http_method_names):
WrappedAPIView.parser_classes = getattr(func, 'parser_classes', WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
APIView.parser_classes) APIView.parser_classes)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes', WrappedAPIView.authentication_classes = getattr(
func, 'authentication_classes',
APIView.authentication_classes) APIView.authentication_classes)
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes', WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',

View File

@ -228,9 +228,11 @@ class ModelField(WritableField):
super(ModelField, self).__init__(*args, **kwargs) super(ModelField, self).__init__(*args, **kwargs)
if self.min_length is not None: if self.min_length is not None:
self.validators.append(validators.MinLengthValidator(self.min_length)) self.validators.append(
validators.MinLengthValidator(self.min_length))
if self.max_length is not None: if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length)) self.validators.append(
validators.MaxLengthValidator(self.max_length))
def from_native(self, value): def from_native(self, value):
rel = getattr(self.model_field, "rel", None) rel = getattr(self.model_field, "rel", None)
@ -349,7 +351,8 @@ class ChoiceField(WritableField):
""" """
super(ChoiceField, self).validate(value) super(ChoiceField, self).validate(value)
if value and not self.valid_value(value): if value and not self.valid_value(value):
raise ValidationError(self.error_messages['invalid_choice'] % {'value': value}) raise ValidationError(
self.error_messages['invalid_choice'] % {'value': value})
def valid_value(self, value): def valid_value(self, value):
""" """
@ -395,7 +398,8 @@ class RegexField(CharField):
form_field_class = forms.RegexField form_field_class = forms.RegexField
def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs): def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
super(RegexField, self).__init__(max_length, min_length, *args, **kwargs) super(RegexField, self).__init__(max_length, min_length, *
args, **kwargs)
self.regex = regex self.regex = regex
def _get_regex(self): def _get_regex(self):
@ -595,7 +599,8 @@ class FileField(WritableField):
if self.max_length is not None and len(file_name) > self.max_length: if self.max_length is not None and len(file_name) > self.max_length:
error_values = {'max': self.max_length, 'length': len(file_name)} error_values = {'max': self.max_length, 'length': len(file_name)}
raise ValidationError(self.error_messages['max_length'] % error_values) raise ValidationError(
self.error_messages['max_length'] % error_values)
if not file_name: if not file_name:
raise ValidationError(self.error_messages['invalid']) raise ValidationError(self.error_messages['invalid'])
if not self.allow_empty_file and not file_size: if not self.allow_empty_file and not file_size:

View File

@ -15,7 +15,8 @@ class CreateModelMixin(object):
Should be mixed in with any `BaseView`. Should be mixed in with any `BaseView`.
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA, files=request.FILES) serializer = self.get_serializer(
data=request.DATA, files=request.FILES)
if serializer.is_valid(): if serializer.is_valid():
self.pre_save(serializer.object) self.pre_save(serializer.object)

View File

@ -43,7 +43,8 @@ class DefaultContentNegotiation(BaseContentNegotiation):
# Check the acceptable media types against each renderer, # Check the acceptable media types against each renderer,
# attempting more specific media types first # attempting more specific media types first
# NB. The inner loop here isn't as bad as it first looks :) # NB. The inner loop here isn't as bad as it first looks :)
# Worst case is we're looping over len(accept_list) * len(self.renderers) # Worst case is we're looping over len(accept_list) *
# len(self.renderers)
for media_type_set in order_by_precedence(accepts): for media_type_set in order_by_precedence(accepts):
for renderer in renderers: for renderer in renderers:
for media_type in media_type_set: for media_type in media_type_set:
@ -56,7 +57,8 @@ class DefaultContentNegotiation(BaseContentNegotiation):
return renderer, renderer.media_type return renderer, renderer.media_type
else: else:
# Eg client requests 'application/json; indent=8' # Eg client requests 'application/json; indent=8'
# Accepted media type is 'application/json; indent=8' # Accepted media type is 'application/json;
# indent=8'
return renderer, media_type return renderer, media_type
raise exceptions.NotAcceptable(available_renderers=renderers) raise exceptions.NotAcceptable(available_renderers=renderers)

View File

@ -68,7 +68,8 @@ class BasePaginationSerializer(serializers.Serializer):
else: else:
context_kwarg = {} context_kwarg = {}
self.fields[results_field] = object_serializer(source='object_list', **context_kwarg) self.fields[results_field] = object_serializer(
source='object_list', **context_kwarg)
def to_native(self, obj): def to_native(self, obj):
""" """

View File

@ -151,7 +151,8 @@ class XMLParser(BaseParser):
if len(children) == 0: if len(children) == 0:
return self._type_convert(element.text) return self._type_convert(element.text)
else: else:
# if the fist child tag is list-item means all children are list-item # if the fist child tag is list-item means all children are list-
# item
if children[0].tag == "list-item": if children[0].tag == "list-item":
data = [] data = []
for child in children: for child in children:

View File

@ -35,9 +35,11 @@ class RelatedField(WritableField):
super(RelatedField, self).initialize(parent, field_name) super(RelatedField, self).initialize(parent, field_name)
if self.queryset is None and not self.read_only: if self.queryset is None and not self.read_only:
try: try:
manager = getattr(self.parent.opts.model, self.source or field_name) manager = getattr(
self.parent.opts.model, self.source or field_name)
if hasattr(manager, 'related'): # Forward if hasattr(manager, 'related'): # Forward
self.queryset = manager.related.model._default_manager.all() self.queryset = manager.related.model._default_manager.all(
)
else: # Reverse else: # Reverse
self.queryset = manager.field.rel.to._default_manager.all() self.queryset = manager.field.rel.to._default_manager.all()
except: except:
@ -194,13 +196,15 @@ class PrimaryKeyRelatedField(RelatedField):
return desc return desc
return "%s - %s" % (desc, ident) return "%s - %s" % (desc, ident)
# TODO: Possibly change this to just take `obj`, through prob less performant # TODO: Possibly change this to just take `obj`, through prob less
# performant
def to_native(self, pk): def to_native(self, pk):
return pk return pk
def from_native(self, data): def from_native(self, data):
if self.queryset is None: if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument') raise Exception(
'Writable related fields must include a `queryset` argument')
try: try:
return self.queryset.get(pk=data) return self.queryset.get(pk=data)
@ -268,7 +272,8 @@ class ManyPrimaryKeyRelatedField(ManyRelatedField):
def from_native(self, data): def from_native(self, data):
if self.queryset is None: if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument') raise Exception(
'Writable related fields must include a `queryset` argument')
try: try:
return self.queryset.get(pk=data) return self.queryset.get(pk=data)
@ -302,7 +307,8 @@ class SlugRelatedField(RelatedField):
def from_native(self, data): def from_native(self, data):
if self.queryset is None: if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument') raise Exception(
'Writable related fields must include a `queryset` argument')
try: try:
return self.queryset.get(**{self.slug_field: data}) return self.queryset.get(**{self.slug_field: data})
@ -394,10 +400,12 @@ class HyperlinkedRelatedField(RelatedField):
# Convert URL -> model instance pk # Convert URL -> model instance pk
# TODO: Use values_list # TODO: Use values_list
if self.queryset is None: if self.queryset is None:
raise Exception('Writable related fields must include a `queryset` argument') raise Exception(
'Writable related fields must include a `queryset` argument')
try: try:
http_prefix = value.startswith('http:') or value.startswith('https:') http_prefix = value.startswith(
'http:') or value.startswith('https:')
except AttributeError: except AttributeError:
msg = self.error_messages['incorrect_type'] msg = self.error_messages['incorrect_type']
raise ValidationError(msg % type(value).__name__) raise ValidationError(msg % type(value).__name__)

View File

@ -33,7 +33,8 @@ class BaseRenderer(object):
format = None format = None
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
raise NotImplemented('Renderer class requires .render() to be implemented') raise NotImplemented(
'Renderer class requires .render() to be implemented')
class JSONRenderer(BaseRenderer): class JSONRenderer(BaseRenderer):
@ -206,7 +207,8 @@ class TemplateHTMLRenderer(BaseRenderer):
return [self.template_name] return [self.template_name]
elif hasattr(view, 'get_template_names'): elif hasattr(view, 'get_template_names'):
return view.get_template_names() return view.get_template_names()
raise ConfigurationError('Returned a template response with no template_name') raise ConfigurationError(
'Returned a template response with no template_name')
def get_exception_template(self, response): def get_exception_template(self, response):
template_names = [name % {'status_code': response.status_code} template_names = [name % {'status_code': response.status_code}
@ -356,7 +358,8 @@ class BrowsableAPIRenderer(BaseRenderer):
fields = self.serializer_to_form_fields(serializer) fields = self.serializer_to_form_fields(serializer)
# Creating an on the fly form see: # Creating an on the fly form see:
# http://stackoverflow.com/questions/3915024/dynamically-creating-classes-python # http://stackoverflow.com/questions/3915024/dynamically-creating-
# classes-python
OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields) OnTheFlyForm = type("OnTheFlyForm", (forms.Form,), fields)
data = (obj is not None) and serializer.data or None data = (obj is not None) and serializer.data or None
form_instance = OnTheFlyForm(data) form_instance = OnTheFlyForm(data)
@ -370,7 +373,8 @@ class BrowsableAPIRenderer(BaseRenderer):
""" """
# If we're not using content overloading there's no point in supplying a generic form, # If we're not using content overloading there's no point in supplying a generic form,
# as the view won't treat the form's value as the content of the request. # as the view won't treat the form's value as the content of the
# request.
if not (api_settings.FORM_CONTENT_OVERRIDE if not (api_settings.FORM_CONTENT_OVERRIDE
and api_settings.FORM_CONTENTTYPE_OVERRIDE): and api_settings.FORM_CONTENTTYPE_OVERRIDE):
return None return None
@ -424,7 +428,8 @@ class BrowsableAPIRenderer(BaseRenderer):
response = renderer_context['response'] response = renderer_context['response']
renderer = self.get_default_renderer(view) renderer = self.get_default_renderer(view)
content = self.get_content(renderer, data, accepted_media_type, renderer_context) content = self.get_content(
renderer, data, accepted_media_type, renderer_context)
put_form = self.get_form(view, 'PUT', request) put_form = self.get_form(view, 'PUT', request)
post_form = self.get_form(view, 'POST', request) post_form = self.get_form(view, 'POST', request)

View File

@ -266,7 +266,8 @@ class Request(object):
self._data = self._request.POST self._data = self._request.POST
self._files = self._request.FILES self._files = self._request.FILES
# Method overloading - change the method and remove the param from the content. # Method overloading - change the method and remove the param from the
# content.
if (self._METHOD_PARAM and if (self._METHOD_PARAM and
self._METHOD_PARAM in self._data): self._METHOD_PARAM in self._data):
self._method = self._data[self._METHOD_PARAM].upper() self._method = self._data[self._METHOD_PARAM].upper()

View File

@ -21,6 +21,7 @@ except ImportError:
print("Coverage is not installed. Aborting...") print("Coverage is not installed. Aborting...")
exit(1) exit(1)
def report(cov, cov_files): def report(cov, cov_files):
pc = cov.report(cov_files) pc = cov.report(cov_files)
@ -32,6 +33,7 @@ def report(cov, cov_files):
return pc return pc
def prepare_report(project_dir): def prepare_report(project_dir):
cov_files = [] cov_files = []
@ -52,7 +54,8 @@ def prepare_report(project_dir):
if 'rest_framework.py' in files: if 'rest_framework.py' in files:
files.remove('rest_framework.py') files.remove('rest_framework.py')
cov_files.extend([os.path.join(path, file) for file in files if file.endswith('.py')]) cov_files.extend([os.path.join(
path, file) for file in files if file.endswith('.py')])
return cov_files return cov_files
@ -96,7 +99,7 @@ def main():
report(cov, cov_files) report(cov, cov_files)
pc = report(cov, cov_files) pc = report(cov, cov_files)
if failures <> 0: if failures != 0:
sys.exit(failures) sys.exit(failures)
if pc < settings.CODE_COVERAGE_THRESHOLD: if pc < settings.CODE_COVERAGE_THRESHOLD:

View File

@ -3,7 +3,8 @@ import os
import sys import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "rest_framework.tests.settings") os.environ.setdefault(
"DJANGO_SETTINGS_MODULE", "rest_framework.tests.settings")
from django.core.management import execute_from_command_line from django.core.management import execute_from_command_line

View File

@ -28,7 +28,8 @@ class DictWithMetadata(dict):
Overriden to remove metadata from the dict, since it shouldn't be pickled Overriden to remove metadata from the dict, since it shouldn't be pickled
and may in some instances be unpickleable. and may in some instances be unpickleable.
""" """
# return an instance of the first dict in MRO that isn't a DictWithMetadata # return an instance of the first dict in MRO that isn't a
# DictWithMetadata
for base in self.__class__.__mro__: for base in self.__class__.__mro__:
if not isinstance(base, DictWithMetadata) and isinstance(base, dict): if not isinstance(base, DictWithMetadata) and isinstance(base, dict):
return base(self) return base(self)
@ -230,12 +231,14 @@ class BaseSerializer(Field):
if field_name in self._errors: if field_name in self._errors:
continue continue
try: try:
validate_method = getattr(self, 'validate_%s' % field_name, None) validate_method = getattr(
self, 'validate_%s' % field_name, None)
if validate_method: if validate_method:
source = field.source or field_name source = field.source or field_name
attrs = validate_method(attrs, source) attrs = validate_method(attrs, source)
except ValidationError as err: except ValidationError as err:
self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages) self._errors[field_name] = self._errors.get(
field_name, []) + list(err.messages)
# If there are already errors, we don't run .validate() because # If there are already errors, we don't run .validate() because
# field-validation failed and thus `attrs` may not be complete. # field-validation failed and thus `attrs` may not be complete.
@ -246,7 +249,8 @@ class BaseSerializer(Field):
except ValidationError as err: except ValidationError as err:
if hasattr(err, 'message_dict'): if hasattr(err, 'message_dict'):
for field_name, error_messages in err.message_dict.items(): for field_name, error_messages in err.message_dict.items():
self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages) self._errors[field_name] = self._errors.get(
field_name, []) + list(error_messages)
elif hasattr(err, 'messages'): elif hasattr(err, 'messages'):
self._errors['non_field_errors'] = err.messages self._errors['non_field_errors'] = err.messages

View File

@ -116,7 +116,8 @@ def import_from_string(val, setting_name):
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
return getattr(module, class_name) return getattr(module, class_name)
except ImportError as e: except ImportError as e:
msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) msg = "Could not import '%s' for API setting '%s'. %s: %s." % (
val, setting_name, e.__class__.__name__, e)
raise ImportError(msg) raise ImportError(msg)

View File

@ -116,14 +116,17 @@ TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '&gt;', '"', "'"]
DOTS = ['&middot;', '*', '\xe2\x80\xa2', '&#149;', '&bull;', '&#8226;'] DOTS = ['&middot;', '*', '\xe2\x80\xa2', '&#149;', '&bull;', '&#8226;']
unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)')
word_split_re = re.compile(r'(\s+)') word_split_re = re.compile(r'(\s+)')
punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \ punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' %
('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), (
'|'.join([re.escape(
x) for x in LEADING_PUNCTUATION]),
'|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION])))
simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+') link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+')
html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE)
hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL)
trailing_empty_content_re = re.compile(r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z') trailing_empty_content_re = re.compile(
r'(?:<p>(?:&nbsp;|\s|<br \/>)*?</p>\s*)+\Z')
# And the template tags themselves... # And the template tags themselves...
@ -211,7 +214,8 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
If autoescape is True, the link text and URLs will get autoescaped. If autoescape is True, the link text and URLs will get autoescaped.
""" """
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x trim_url = lambda x, limit=trim_url_limit: limit is not None and (
len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
safe_input = isinstance(text, SafeData) safe_input = isinstance(text, SafeData)
words = word_split_re.split(force_unicode(text)) words = word_split_re.split(force_unicode(text))
nofollow_attr = nofollow and ' rel="nofollow"' or '' nofollow_attr = nofollow and ' rel="nofollow"' or ''
@ -225,8 +229,8 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
url = None url = None
if middle.startswith('http://') or middle.startswith('https://'): if middle.startswith('http://') or middle.startswith('https://'):
url = middle url = middle
elif middle.startswith('www.') or ('@' not in middle and \ elif middle.startswith('www.') or ('@' not in middle and
middle and middle[0] in string.ascii_letters + string.digits and \ middle and middle[0] in string.ascii_letters + string.digits and
(middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))):
url = 'http://%s' % middle url = 'http://%s' % middle
elif '@' in middle and not ':' in middle and simple_email_re.match(middle): elif '@' in middle and not ':' in middle and simple_email_re.match(middle):
@ -238,7 +242,8 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
if autoescape and not safe_input: if autoescape and not safe_input:
lead, trail = escape(lead), escape(trail) lead, trail = escape(lead), escape(trail)
url, trimmed = escape(url), escape(trimmed) url, trimmed = escape(url), escape(trimmed)
middle = '<a href="%s"%s>%s</a>' % (url, nofollow_attr, trimmed) middle = '<a href="%s"%s>%s</a>' % (
url, nofollow_attr, trimmed)
words[i] = mark_safe('%s%s%s' % (lead, middle, trail)) words[i] = mark_safe('%s%s%s' % (lead, middle, trail))
else: else:
if safe_input: if safe_input:

View File

@ -55,7 +55,8 @@
# group.save() # group.save()
# self.assertEqual(0, User.objects.count()) # self.assertEqual(0, User.objects.count())
# response = self.client.post('/users/', {'username': 'bar', 'password': 'baz', 'groups': [group.id]}) # response = self.client.post('/users/', {'username': 'bar', 'password':
# 'baz', 'groups': [group.id]})
# self.assertEqual(response.status_code, 201) # self.assertEqual(response.status_code, 201)
# self.assertEqual(1, User.objects.count()) # self.assertEqual(1, User.objects.count())
@ -77,7 +78,8 @@
# group.save() # group.save()
# self.assertEqual(0, User.objects.count()) # self.assertEqual(0, User.objects.count())
# response = self.client.post('/customusers/', {'username': 'bar', 'groups': [group.id]}) # response = self.client.post('/customusers/', {'username': 'bar',
# 'groups': [group.id]})
# self.assertEqual(response.status_code, 201) # self.assertEqual(response.status_code, 201)
# self.assertEqual(1, CustomUser.objects.count()) # self.assertEqual(1, CustomUser.objects.count())

View File

@ -13,7 +13,8 @@ MANAGERS = ADMINS
DATABASES = { DATABASES = {
'default': { 'default': {
'ENGINE': 'django.db.backends.sqlite3', 'ENGINE': 'django.db.backends.sqlite3',
# Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or 'oracle'. # Add 'postgresql_psycopg2', 'postgresql', 'mysql', 'sqlite3' or
# 'oracle'.
'NAME': 'sqlite.db', # Or path to database file if using sqlite3. 'NAME': 'sqlite.db', # Or path to database file if using sqlite3.
'USER': '', # Not used with sqlite3. 'USER': '', # Not used with sqlite3.
'PASSWORD': '', # Not used with sqlite3. 'PASSWORD': '', # Not used with sqlite3.
@ -114,7 +115,8 @@ if django.VERSION < (1, 3):
INSTALLED_APPS += ('staticfiles',) INSTALLED_APPS += ('staticfiles',)
# If we're running on the Jenkins server we want to archive the coverage reports as XML. # If we're running on the Jenkins server we want to archive the coverage
# reports as XML.
import os import os
if os.environ.get('HUDSON_URL', None): if os.environ.get('HUDSON_URL', None):

View File

@ -22,10 +22,14 @@ class MockView(APIView):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({'a': 1, 'b': 2, 'c': 3})
urlpatterns = patterns('', urlpatterns = patterns('',
(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), (r'^session/$', MockView.as_view(authentication_classes=[
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), SessionAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^basic/$', MockView.as_view(authentication_classes=[
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[
TokenAuthentication])),
(r'^auth-token/$',
'rest_framework.authtoken.views.obtain_auth_token'),
) )
@ -38,18 +42,23 @@ class BasicAuthTests(TestCase):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(
self.username, self.email, self.password)
def test_post_form_passing_basic_auth(self): def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() auth = 'Basic %s' % base64.encodestring(
response = self.csrf_client.post('/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) '%s:%s' % (self.username, self.password)).strip()
response = self.csrf_client.post(
'/basic/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_passing_basic_auth(self): def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
auth = 'Basic %s' % base64.encodestring('%s:%s' % (self.username, self.password)).strip() auth = 'Basic %s' % base64.encodestring(
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', '%s:%s' % (self.username, self.password)).strip()
response = self.csrf_client.post(
'/basic/', json.dumps({'example': 'example'}), 'application/json',
HTTP_AUTHORIZATION=auth) HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -60,7 +69,8 @@ class BasicAuthTests(TestCase):
def test_post_json_failing_basic_auth(self): def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails""" """Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/basic/', json.dumps(
{'example': 'example'}), 'application/json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
@ -75,7 +85,8 @@ class SessionAuthTests(TestCase):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(
self.username, self.email, self.password)
def tearDown(self): def tearDown(self):
self.csrf_client.logout() self.csrf_client.logout()
@ -92,16 +103,20 @@ class SessionAuthTests(TestCase):
""" """
Ensure POSTing form over session authentication with logged in user and CSRF token passes. Ensure POSTing form over session authentication with logged in user and CSRF token passes.
""" """
self.non_csrf_client.login(username=self.username, password=self.password) self.non_csrf_client.login(
response = self.non_csrf_client.post('/session/', {'example': 'example'}) username=self.username, password=self.password)
response = self.non_csrf_client.post(
'/session/', {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_put_form_session_auth_passing(self): def test_put_form_session_auth_passing(self):
""" """
Ensure PUTting form over session authentication with logged in user and CSRF token passes. Ensure PUTting form over session authentication with logged in user and CSRF token passes.
""" """
self.non_csrf_client.login(username=self.username, password=self.password) self.non_csrf_client.login(
response = self.non_csrf_client.put('/session/', {'example': 'example'}) username=self.username, password=self.password)
response = self.non_csrf_client.put(
'/session/', {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_session_auth_failing(self): def test_post_form_session_auth_failing(self):
@ -121,7 +136,8 @@ class TokenAuthTests(TestCase):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(
self.username, self.email, self.password)
self.key = 'abcd1234' self.key = 'abcd1234'
self.token = Token.objects.create(key=self.key, user=self.user) self.token = Token.objects.create(key=self.key, user=self.user)
@ -129,13 +145,15 @@ class TokenAuthTests(TestCase):
def test_post_form_passing_token_auth(self): def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(
'/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_passing_token_auth(self): def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', response = self.csrf_client.post(
'/token/', json.dumps({'example': 'example'}), 'application/json',
HTTP_AUTHORIZATION=auth) HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -146,7 +164,8 @@ class TokenAuthTests(TestCase):
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails""" """Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/token/', json.dumps(
{'example': 'example'}), 'application/json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):

View File

@ -26,9 +26,12 @@ class NestedResourceInstance(APIView):
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^$', Root.as_view()), url(r'^$', Root.as_view()),
url(r'^resource/$', ResourceRoot.as_view()), url(r'^resource/$', ResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), url(r'^resource/(?P<key>[0-9]+)$',
url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()), ResourceInstance.as_view()),
url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()), url(r'^resource/(?P<key>[0-9]+)/$',
NestedResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$',
NestedResourceInstance.as_view()),
) )
@ -49,22 +52,28 @@ class BreadcrumbTests(TestCase):
def test_resource_instance_breadcrumbs(self): def test_resource_instance_breadcrumbs(self):
url = '/resource/123' url = '/resource/123'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root',
'/resource/'),
('Resource Instance', '/resource/123')]) ('Resource Instance', '/resource/123')])
def test_nested_resource_breadcrumbs(self): def test_nested_resource_breadcrumbs(self):
url = '/resource/123/' url = '/resource/123/'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root',
('Resource Instance', '/resource/123'), '/resource/'),
('Resource Instance',
'/resource/123'),
('Nested Resource Root', '/resource/123/')]) ('Nested Resource Root', '/resource/123/')])
def test_nested_resource_instance_breadcrumbs(self): def test_nested_resource_instance_breadcrumbs(self):
url = '/resource/123/abc' url = '/resource/123/abc'
self.assertEqual(get_breadcrumbs(url), [('Root', '/'), self.assertEqual(get_breadcrumbs(url), [('Root', '/'),
('Resource Root', '/resource/'), ('Resource Root',
('Resource Instance', '/resource/123'), '/resource/'),
('Nested Resource Root', '/resource/123/'), ('Resource Instance',
'/resource/123'),
('Nested Resource Root',
'/resource/123/'),
('Nested Resource Instance', '/resource/123/abc')]) ('Nested Resource Instance', '/resource/123/abc')])
def test_broken_url_breadcrumbs_handled_gracefully(self): def test_broken_url_breadcrumbs_handled_gracefully(self):

View File

@ -60,7 +60,8 @@ class DecoratorTestCase(TestCase):
request = self.factory.post('/') request = self.factory.post('/')
response = view(request) response = view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(
response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_calling_put_method(self): def test_calling_put_method(self):
@api_view(['GET', 'PUT']) @api_view(['GET', 'PUT'])
@ -73,7 +74,8 @@ class DecoratorTestCase(TestCase):
request = self.factory.post('/') request = self.factory.post('/')
response = view(request) response = view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(
response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_calling_patch_method(self): def test_calling_patch_method(self):
@api_view(['GET', 'PATCH']) @api_view(['GET', 'PATCH'])
@ -86,7 +88,8 @@ class DecoratorTestCase(TestCase):
request = self.factory.post('/') request = self.factory.post('/')
response = view(request) response = view(request)
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(
response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
def test_renderer_classes(self): def test_renderer_classes(self):
@api_view(['GET']) @api_view(['GET'])
@ -146,4 +149,5 @@ class DecoratorTestCase(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
response = view(request) response = view(request)
self.assertEquals(response.status_code, status.HTTP_429_TOO_MANY_REQUESTS) self.assertEquals(
response.status_code, status.HTTP_429_TOO_MANY_REQUESTS)

View File

@ -30,7 +30,8 @@ class FileSerializerTests(TestCase):
file = StringIO.StringIO('stuff') file = StringIO.StringIO('stuff')
file.name = 'stuff.txt' file.name = 'stuff.txt'
file.size = file.len file.size = file.len
serializer = UploadedFileSerializer(data={'created': now}, files={'file': file}) serializer = UploadedFileSerializer(
data={'created': now}, files={'file': file})
uploaded_file = UploadedFile(file=file, created=now) uploaded_file = UploadedFile(file=file, created=now)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.object.created, uploaded_file.created) self.assertEquals(serializer.object.created, uploaded_file.created)

View File

@ -56,14 +56,16 @@ class IntegrationTestFiltering(TestCase):
""" """
base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
for i in range(10): for i in range(10):
text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. text = chr(i + ord(
base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
decimal = base_data[1] + i decimal = base_data[1] + i
date = base_data[2] - datetime.timedelta(days=i * 2) date = base_data[2] - datetime.timedelta(days=i * 2)
FilterableItem(text=text, decimal=decimal, date=date).save() FilterableItem(text=text, decimal=decimal, date=date).save()
self.objects = FilterableItem.objects self.objects = FilterableItem.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal,
'date': obj.date}
for obj in self.objects.all() for obj in self.objects.all()
] ]
@ -85,12 +87,14 @@ class IntegrationTestFiltering(TestCase):
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['decimal'] == search_decimal] expected_data = [f for f in self.data if f['decimal']
== search_decimal]
self.assertEquals(response.data, expected_data) self.assertEquals(response.data, expected_data)
# Tests that the date filter works. # Tests that the date filter works.
search_date = datetime.date(2012, 9, 22) search_date = datetime.date(2012, 9, 22)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-09-22' request = factory.get(
'/?date=%s' % search_date) # search_date str: '2012-09-22'
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] == search_date] expected_data = [f for f in self.data if f['date'] == search_date]
@ -110,7 +114,8 @@ class IntegrationTestFiltering(TestCase):
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
self.assertEquals(response.data, self.data) self.assertEquals(response.data, self.data)
# Tests that the decimal filter set with 'lt' in the filter class works. # Tests that the decimal filter set with 'lt' in the filter class
# works.
search_decimal = Decimal('4.25') search_decimal = Decimal('4.25')
request = factory.get('/?decimal=%s' % search_decimal) request = factory.get('/?decimal=%s' % search_decimal)
response = view(request).render() response = view(request).render()
@ -120,24 +125,28 @@ class IntegrationTestFiltering(TestCase):
# Tests that the date filter set with 'gt' in the filter class works. # Tests that the date filter set with 'gt' in the filter class works.
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?date=%s' % search_date) # search_date str: '2012-10-02' request = factory.get(
'/?date=%s' % search_date) # search_date str: '2012-10-02'
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date] expected_data = [f for f in self.data if f['date'] > search_date]
self.assertEquals(response.data, expected_data) self.assertEquals(response.data, expected_data)
# Tests that the text filter set with 'icontains' in the filter class works. # Tests that the text filter set with 'icontains' in the filter class
# works.
search_text = 'ff' search_text = 'ff'
request = factory.get('/?text=%s' % search_text) request = factory.get('/?text=%s' % search_text)
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if search_text in f['text'].lower()] expected_data = [f for f in self.data if search_text in f[
'text'].lower()]
self.assertEquals(response.data, expected_data) self.assertEquals(response.data, expected_data)
# Tests that multiple filters works. # Tests that multiple filters works.
search_decimal = Decimal('5.25') search_decimal = Decimal('5.25')
search_date = datetime.date(2012, 10, 2) search_date = datetime.date(2012, 10, 2)
request = factory.get('/?decimal=%s&date=%s' % (search_decimal, search_date)) request = factory.get(
'/?decimal=%s&date=%s' % (search_decimal, search_date))
response = view(request).render() response = view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)
expected_data = [f for f in self.data if f['date'] > search_date and expected_data = [f for f in self.data if f['date'] > search_date and

View File

@ -84,8 +84,10 @@ class TestRootView(TestCase):
request = factory.put('/', json.dumps(content), request = factory.put('/', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEquals(
self.assertEquals(response.data, {"detail": "Method 'PUT' not allowed."}) response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEquals(
response.data, {"detail": "Method 'PUT' not allowed."})
def test_delete_root_view(self): def test_delete_root_view(self):
""" """
@ -93,8 +95,10 @@ class TestRootView(TestCase):
""" """
request = factory.delete('/') request = factory.delete('/')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEquals(
self.assertEquals(response.data, {"detail": "Method 'DELETE' not allowed."}) response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEquals(
response.data, {"detail": "Method 'DELETE' not allowed."})
def test_options_root_view(self): def test_options_root_view(self):
""" """
@ -165,8 +169,10 @@ class TestInstanceView(TestCase):
request = factory.post('/', json.dumps(content), request = factory.post('/', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEquals(
self.assertEquals(response.data, {"detail": "Method 'POST' not allowed."}) response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEquals(
response.data, {"detail": "Method 'POST' not allowed."})
def test_put_instance_view(self): def test_put_instance_view(self):
""" """
@ -262,7 +268,8 @@ class TestInstanceView(TestCase):
at the requested url if it doesn't exist. at the requested url if it doesn't exist.
""" """
content = {'text': 'foobar'} content = {'text': 'foobar'}
# pk fields can not be created on demand, only the database can set th pk for a new object # pk fields can not be created on demand, only the database can set th
# pk for a new object
request = factory.put('/5', json.dumps(content), request = factory.put('/5', json.dumps(content),
content_type='application/json') content_type='application/json')
response = self.view(request, pk=5).render() response = self.view(request, pk=5).render()
@ -280,7 +287,8 @@ class TestInstanceView(TestCase):
content_type='application/json') content_type='application/json')
response = self.slug_based_view(request, slug='test_slug').render() response = self.slug_based_view(request, slug='test_slug').render()
self.assertEquals(response.status_code, status.HTTP_201_CREATED) self.assertEquals(response.status_code, status.HTTP_201_CREATED)
self.assertEquals(response.data, {'slug': 'test_slug', 'text': 'foobar'}) self.assertEquals(
response.data, {'slug': 'test_slug', 'text': 'foobar'})
new_obj = SlugBasedModel.objects.get(slug='test_slug') new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEquals(new_obj.text, 'foobar') self.assertEquals(new_obj.text, 'foobar')

View File

@ -9,9 +9,11 @@ factory = RequestFactory()
class BlogPostCommentSerializer(serializers.ModelSerializer): class BlogPostCommentSerializer(serializers.ModelSerializer):
url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') url = serializers.HyperlinkedIdentityField(
view_name='blogpostcomment-detail')
text = serializers.CharField() text = serializers.CharField()
blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') blog_post_url = serializers.HyperlinkedRelatedField(
source='blog_post', view_name='blogpost-detail')
class Meta: class Meta:
model = BlogPostComment model = BlogPostComment
@ -20,7 +22,8 @@ class BlogPostCommentSerializer(serializers.ModelSerializer):
class PhotoSerializer(serializers.Serializer): class PhotoSerializer(serializers.Serializer):
description = serializers.CharField() description = serializers.CharField()
album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', album_url = serializers.HyperlinkedRelatedField(
source='album', view_name='album-detail',
queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title')
def restore_object(self, attrs, instance=None): def restore_object(self, attrs, instance=None):
@ -81,17 +84,32 @@ class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^basic/$', BasicList.as_view(), name='basicmodel-list'), url(r'^basic/$',
url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(), name='basicmodel-detail'), BasicList.as_view(), name='basicmodel-list'),
url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(), name='anchor-detail'), url(r'^basic/(?P<pk>\d+)/$', BasicDetail.as_view(),
url(r'^manytomany/$', ManyToManyList.as_view(), name='manytomanymodel-list'), name='basicmodel-detail'),
url(r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), url(r'^anchor/(?P<pk>\d+)/$', AnchorDetail.as_view(),
url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), name='anchor-detail'),
url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), url(r'^manytomany/$', ManyToManyList.as_view(),
url(r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), name='manytomanymodel-list'),
url(r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), url(
url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), r'^manytomany/(?P<pk>\d+)/$', ManyToManyDetail.as_view(
url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), ),
name='manytomanymodel-detail'),
url(r'^posts/(?P<pk>\d+)/$', BlogPostDetail.as_view(),
name='blogpost-detail'),
url(r'^comments/$', BlogPostCommentListCreate.as_view(),
name='blogpostcomment-list'),
url(
r'^comments/(?P<pk>\d+)/$', BlogPostCommentDetail.as_view(),
name='blogpostcomment-detail'),
url(
r'^albums/(?P<title>\w[\w-]*)/$', AlbumDetail.as_view(),
name='album-detail'),
url(r'^photos/$',
PhotoListCreate.as_view(), name='photo-list'),
url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(
), name='optionalrelationmodel-detail'),
) )
@ -201,7 +219,8 @@ class TestCreateWithForeignKeys(TestCase):
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response['Location'], 'http://testserver/comments/1/') self.assertEqual(response['Location'], 'http://testserver/comments/1/')
self.assertEqual(self.post.blogpostcomment_set.count(), 1) self.assertEqual(self.post.blogpostcomment_set.count(), 1)
self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') self.assertEqual(
self.post.blogpostcomment_set.all()[0].text, 'A test comment')
class TestCreateWithForeignKeysAndCustomSlug(TestCase): class TestCreateWithForeignKeysAndCustomSlug(TestCase):
@ -226,7 +245,8 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
self.assertNotIn('Location', response, self.assertNotIn('Location', response,
msg='Location should only be included if there is a "url" field on the serializer') msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.count(), 1)
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') self.assertEqual(
self.post.photo_set.all()[0].description, 'A test photo')
class TestOptionalRelationHyperlinkedView(TestCase): class TestOptionalRelationHyperlinkedView(TestCase):
@ -239,7 +259,8 @@ class TestOptionalRelationHyperlinkedView(TestCase):
OptionalRelationModel().save() OptionalRelationModel().save()
self.objects = OptionalRelationModel.objects self.objects = OptionalRelationModel.objects
self.detail_view = OptionalRelationDetail.as_view() self.detail_view = OptionalRelationDetail.as_view()
self.data = {"url": "http://testserver/optionalrelation/1/", "other": None} self.data = {"url":
"http://testserver/optionalrelation/1/", "other": None}
def test_get_detail_view(self): def test_get_detail_view(self):
""" """

View File

@ -103,14 +103,16 @@ class IntegrationTestPaginationAndFiltering(TestCase):
""" """
base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8)) base_data = ('a', Decimal('0.25'), datetime.date(2012, 10, 8))
for i in range(26): for i in range(26):
text = chr(i + ord(base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc. text = chr(i + ord(
base_data[0])) * 3 # Produces string 'aaa', 'bbb', etc.
decimal = base_data[1] + i decimal = base_data[1] + i
date = base_data[2] - datetime.timedelta(days=i * 2) date = base_data[2] - datetime.timedelta(days=i * 2)
FilterableItem(text=text, decimal=decimal, date=date).save() FilterableItem(text=text, decimal=decimal, date=date).save()
self.objects = FilterableItem.objects self.objects = FilterableItem.objects
self.data = [ self.data = [
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal,
'date': obj.date}
for obj in self.objects.all() for obj in self.objects.all()
] ]
self.view = FilterFieldsRootView.as_view() self.view = FilterFieldsRootView.as_view()
@ -180,7 +182,8 @@ class UnitTestPagination(TestCase):
""" """
Ensure context gets passed through to the object serializer. Ensure context gets passed through to the object serializer.
""" """
serializer = PassOnContextPaginationSerializer(self.first_page, context={'foo': 'bar'}) serializer = PassOnContextPaginationSerializer(
self.first_page, context={'foo': 'bar'})
serializer.data serializer.data
results = serializer.fields[serializer.results_field] results = serializer.fields[serializer.results_field]
self.assertEquals(serializer.context, results.context) self.assertEquals(serializer.context, results.context)
@ -254,7 +257,8 @@ class TestCustomPaginateByParam(TestCase):
class CustomField(serializers.Field): class CustomField(serializers.Field):
def to_native(self, value): def to_native(self, value):
if not 'view' in self.context: if not 'view' in self.context:
raise RuntimeError("context isn't getting passed into custom field") raise RuntimeError(
"context isn't getting passed into custom field")
return "value" return "value"
@ -277,4 +281,3 @@ class TestContextPassedToCustomField(TestCase):
response = self.view(request).render() response = self.view(request).render()
self.assertEquals(response.status_code, status.HTTP_200_OK) self.assertEquals(response.status_code, status.HTTP_200_OK)

View File

@ -18,16 +18,19 @@ class FieldTests(TestCase):
https://github.com/tomchristie/django-rest-framework/issues/446 https://github.com/tomchristie/django-rest-framework/issues/446
""" """
field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) field = serializers.PrimaryKeyRelatedField(
queryset=NullModel.objects.all())
self.assertRaises(serializers.ValidationError, field.from_native, '') self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) self.assertRaises(serializers.ValidationError, field.from_native, [])
def test_hyperlinked_related_field_with_empty_string(self): def test_hyperlinked_related_field_with_empty_string(self):
field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') field = serializers.HyperlinkedRelatedField(
queryset=NullModel.objects.all(), view_name='')
self.assertRaises(serializers.ValidationError, field.from_native, '') self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) self.assertRaises(serializers.ValidationError, field.from_native, [])
def test_slug_related_field_with_empty_string(self): def test_slug_related_field_with_empty_string(self):
field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') field = serializers.SlugRelatedField(
queryset=NullModel.objects.all(), slug_field='pk')
self.assertRaises(serializers.ValidationError, field.from_native, '') self.assertRaises(serializers.ValidationError, field.from_native, '')
self.assertRaises(serializers.ValidationError, field.from_native, []) self.assertRaises(serializers.ValidationError, field.from_native, [])

View File

@ -8,18 +8,28 @@ def dummy_view(request, pk):
pass pass
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), name='manytomanysource-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), name='manytomanytarget-detail'),
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), name='foreignkeysource-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view,
name='foreignkeytarget-detail'),
url(
r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view,
name='nullableforeignkeysource-detail'),
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view,
name='onetoonetarget-detail'),
url(
r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view,
name='nullableonetoonesource-detail'),
) )
class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer): class ManyToManyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='manytomanysource-detail') sources = serializers.ManyHyperlinkedRelatedField(
view_name='manytomanysource-detail')
class Meta: class Meta:
model = ManyToManyTarget model = ManyToManyTarget
@ -31,7 +41,8 @@ class ManyToManySourceSerializer(serializers.HyperlinkedModelSerializer):
class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer): class ForeignKeyTargetSerializer(serializers.HyperlinkedModelSerializer):
sources = serializers.ManyHyperlinkedRelatedField(view_name='foreignkeysource-detail') sources = serializers.ManyHyperlinkedRelatedField(
view_name='foreignkeysource-detail')
class Meta: class Meta:
model = ForeignKeyTarget model = ForeignKeyTarget
@ -50,7 +61,8 @@ class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer)
# OneToOne # OneToOne
class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer): class NullableOneToOneTargetSerializer(serializers.HyperlinkedModelSerializer):
nullable_source = serializers.HyperlinkedRelatedField(view_name='nullableonetoonesource-detail') nullable_source = serializers.HyperlinkedRelatedField(
view_name='nullableonetoonesource-detail')
class Meta: class Meta:
model = OneToOneTarget model = OneToOneTarget
@ -74,7 +86,8 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/1/', 'name':
u'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': u'source-2', {'url': '/manytomanysource/2/', 'name': u'source-2',
'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': u'source-3', {'url': '/manytomanysource/3/', 'name': u'source-3',
@ -90,7 +103,8 @@ class HyperlinkedManyToManyTests(TestCase):
'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', {'url': '/manytomanytarget/2/', 'name': u'target-2',
'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} {'url': '/manytomanytarget/3/', 'name':
u'target-3', 'sources': ['/manytomanysource/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -117,7 +131,8 @@ class HyperlinkedManyToManyTests(TestCase):
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']} data = {'url': '/manytomanytarget/1/', 'name': u'target-1',
'sources': ['/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data) serializer = ManyToManyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -128,10 +143,12 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset) serializer = ManyToManyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanytarget/1/', 'name': u'target-1', 'sources': ['/manytomanysource/1/']}, {'url': '/manytomanytarget/1/', 'name':
u'target-1', 'sources': ['/manytomanysource/1/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', {'url': '/manytomanytarget/2/', 'name': u'target-2',
'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']} {'url': '/manytomanytarget/3/', 'name':
u'target-3', 'sources': ['/manytomanysource/3/']}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -149,7 +166,8 @@ class HyperlinkedManyToManyTests(TestCase):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset) serializer = ManyToManySourceSerializer(queryset)
expected = [ expected = [
{'url': '/manytomanysource/1/', 'name': u'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/1/', 'name':
u'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': u'source-2', {'url': '/manytomanysource/2/', 'name': u'source-2',
'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': u'source-3', {'url': '/manytomanysource/3/', 'name': u'source-3',
@ -176,7 +194,8 @@ class HyperlinkedManyToManyTests(TestCase):
'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/1/', '/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/2/', 'name': u'target-2', {'url': '/manytomanytarget/2/', 'name': u'target-2',
'sources': ['/manytomanysource/2/', '/manytomanysource/3/']}, 'sources': ['/manytomanysource/2/', '/manytomanysource/3/']},
{'url': '/manytomanytarget/3/', 'name': u'target-3', 'sources': ['/manytomanysource/3/']}, {'url': '/manytomanytarget/3/', 'name':
u'target-3', 'sources': ['/manytomanysource/3/']},
{'url': '/manytomanytarget/4/', 'name': u'target-4', {'url': '/manytomanytarget/4/', 'name': u'target-4',
'sources': ['/manytomanysource/1/', '/manytomanysource/3/']} 'sources': ['/manytomanysource/1/', '/manytomanysource/3/']}
] ]
@ -199,9 +218,12 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/1/', 'name':
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} {'url': '/foreignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name':
u'source-3', 'target': '/foreignkeytarget/1/'}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -211,12 +233,14 @@ class HyperlinkedForeignKeyTests(TestCase):
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', {'url': '/foreignkeytarget/1/', 'name': u'target-1',
'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, {'url': '/foreignkeytarget/2/', 'name':
u'target-2', 'sources': []},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'} data = {'url': '/foreignkeysource/1/', 'name': u'source-1',
'target': '/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -227,14 +251,18 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/2/'}, {'url': '/foreignkeysource/1/', 'name':
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/2/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'} {'url': '/foreignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name':
u'source-3', 'target': '/foreignkeytarget/1/'}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': 2} data = {'url': '/foreignkeysource/1/', 'name': u'source-1',
'target': 2}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
@ -253,7 +281,8 @@ class HyperlinkedForeignKeyTests(TestCase):
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', {'url': '/foreignkeytarget/1/', 'name': u'target-1',
'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']}, 'sources': ['/foreignkeysource/1/', '/foreignkeysource/2/', '/foreignkeysource/3/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, {'url': '/foreignkeytarget/2/', 'name':
u'target-2', 'sources': []},
] ]
self.assertEquals(new_serializer.data, expected) self.assertEquals(new_serializer.data, expected)
@ -264,14 +293,16 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, {'url': '/foreignkeytarget/1/', 'name':
u'target-1', 'sources': ['/foreignkeysource/2/']},
{'url': '/foreignkeytarget/2/', 'name': u'target-2', {'url': '/foreignkeytarget/2/', 'name': u'target-2',
'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create(self): def test_foreign_key_create(self):
data = {'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'} data = {'url': '/foreignkeysource/4/', 'name': u'source-4',
'target': '/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data) serializer = ForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -282,10 +313,14 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/1/', 'name':
{'url': '/foreignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name': u'source-3', 'target': '/foreignkeytarget/1/'}, {'url': '/foreignkeysource/2/', 'name':
{'url': '/foreignkeysource/4/', 'name': u'source-4', 'target': '/foreignkeytarget/2/'}, u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/3/', 'name':
u'source-3', 'target': '/foreignkeytarget/1/'},
{'url': '/foreignkeysource/4/', 'name':
u'source-4', 'target': '/foreignkeytarget/2/'},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -302,19 +337,23 @@ class HyperlinkedForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'url': '/foreignkeytarget/1/', 'name': u'target-1', 'sources': ['/foreignkeysource/2/']}, {'url': '/foreignkeytarget/1/', 'name':
{'url': '/foreignkeytarget/2/', 'name': u'target-2', 'sources': []}, u'target-1', 'sources': ['/foreignkeysource/2/']},
{'url': '/foreignkeytarget/2/', 'name':
u'target-2', 'sources': []},
{'url': '/foreignkeytarget/3/', 'name': u'target-3', {'url': '/foreignkeytarget/3/', 'name': u'target-3',
'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']}, 'sources': ['/foreignkeysource/1/', '/foreignkeysource/3/']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
data = {'url': '/foreignkeysource/1/', 'name': u'source-1', 'target': None} data = {'url': '/foreignkeysource/1/', 'name': u'source-1',
'target': None}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) self.assertEquals(
serializer.errors, {'target': [u'Value may not be null']})
class HyperlinkedNullableForeignKeyTests(TestCase): class HyperlinkedNullableForeignKeyTests(TestCase):
@ -326,21 +365,26 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(
name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} data = {'url': '/nullableforeignkeysource/4/', 'name':
u'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -351,10 +395,14 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
{'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
{'url': '/nullableforeignkeysource/4/', 'name':
u'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -363,8 +411,10 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': ''} data = {'url': '/nullableforeignkeysource/4/', 'name':
expected_data = {'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} u'source-4', 'target': ''}
expected_data = {'url': '/nullableforeignkeysource/4/',
'name': u'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data) serializer = NullableForeignKeySourceSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -375,15 +425,20 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': '/foreignkeytarget/1/'}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
{'url': '/nullableforeignkeysource/4/', 'name': u'source-4', 'target': None} u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
{'url': '/nullableforeignkeysource/4/', 'name':
u'source-4', 'target': None}
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} data = {'url': '/nullableforeignkeysource/1/', 'name':
u'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -394,9 +449,12 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': None},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -405,8 +463,10 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': ''} data = {'url': '/nullableforeignkeysource/1/', 'name':
expected_data = {'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None} u'source-1', 'target': ''}
expected_data = {'url': '/nullableforeignkeysource/1/',
'name': u'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data) serializer = NullableForeignKeySourceSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -417,9 +477,12 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'url': '/nullableforeignkeysource/1/', 'name': u'source-1', 'target': None}, {'url': '/nullableforeignkeysource/1/', 'name':
{'url': '/nullableforeignkeysource/2/', 'name': u'source-2', 'target': '/foreignkeytarget/1/'}, u'source-1', 'target': None},
{'url': '/nullableforeignkeysource/3/', 'name': u'source-3', 'target': None}, {'url': '/nullableforeignkeysource/2/', 'name':
u'source-2', 'target': '/foreignkeytarget/1/'},
{'url': '/nullableforeignkeysource/3/', 'name':
u'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -460,7 +523,9 @@ class HyperlinkedNullableOneToOneTests(TestCase):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset) serializer = NullableOneToOneTargetSerializer(queryset)
expected = [ expected = [
{'url': '/onetoonetarget/1/', 'name': u'target-1', 'nullable_source': '/nullableonetoonesource/1/'}, {'url': '/onetoonetarget/1/', 'name': u'target-1',
{'url': '/onetoonetarget/2/', 'name': u'target-2', 'nullable_source': None}, 'nullable_source': '/nullableonetoonesource/1/'},
{'url': '/onetoonetarget/2/', 'name': u'target-2',
'nullable_source': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -53,9 +53,12 @@ class ReverseForeignKeyTests(TestCase):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset) serializer = ForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 1, 'name': u'source-1', 'target': {'id': 1,
{'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, 'name': u'target-1'}},
{'id': 3, 'name': u'source-3', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 2, 'name': u'source-2', 'target': {'id': 1,
'name': u'target-1'}},
{'id': 3, 'name': u'source-3', 'target': {'id': 1,
'name': u'target-1'}},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -81,15 +84,18 @@ class NestedNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(
name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset) serializer = NullableForeignKeySourceSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'source-1', 'target': {'id': 1, 'name': u'target-1'}}, {'id': 1, 'name': u'source-1', 'target': {'id': 1,
{'id': 2, 'name': u'source-2', 'target': {'id': 1, 'name': u'target-1'}}, 'name': u'target-1'}},
{'id': 2, 'name': u'source-2', 'target': {'id': 1,
'name': u'target-1'}},
{'id': 3, 'name': u'source-3', 'target': None}, {'id': 3, 'name': u'source-3', 'target': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -108,7 +114,8 @@ class NestedNullableOneToOneTests(TestCase):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset) serializer = NullableOneToOneTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'nullable_source': {'id': 1, 'name': u'source-1', 'target': 1}}, {'id': 1, 'name': u'target-1', 'nullable_source': {
'id': 1, 'name': u'source-1', 'target': 1}},
{'id': 2, 'name': u'target-2', 'nullable_source': None}, {'id': 2, 'name': u'target-2', 'nullable_source': None},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)

View File

@ -270,7 +270,8 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) self.assertEquals(
serializer.errors, {'target': [u'Value may not be null']})
class PKNullableForeignKeyTests(TestCase): class PKNullableForeignKeyTests(TestCase):
@ -280,7 +281,8 @@ class PKNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(
name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):

View File

@ -49,7 +49,8 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 1, 'name': u'target-1', 'sources': [
'source-1', 'source-2', 'source-3']},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': u'target-2', 'sources': []},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -77,10 +78,12 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Object with name=123 does not exist.']}) self.assertEquals(serializer.errors, {'target': [
u'Object with name=123 does not exist.']})
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']} data = {'id': 2, 'name': u'target-2', 'sources': [
'source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -89,7 +92,8 @@ class PKForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset) new_serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 1, 'name': u'target-1', 'sources': [
'source-1', 'source-2', 'source-3']},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': u'target-2', 'sources': []},
] ]
self.assertEquals(new_serializer.data, expected) self.assertEquals(new_serializer.data, expected)
@ -102,7 +106,8 @@ class PKForeignKeyTests(TestCase):
serializer = ForeignKeyTargetSerializer(queryset) serializer = ForeignKeyTargetSerializer(queryset)
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': ['source-2']}, {'id': 1, 'name': u'target-1', 'sources': ['source-2']},
{'id': 2, 'name': u'target-2', 'sources': ['source-1', 'source-3']}, {'id': 2, 'name': u'target-2', 'sources': [
'source-1', 'source-3']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -127,7 +132,8 @@ class PKForeignKeyTests(TestCase):
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']} data = {'id': 3, 'name': u'target-3', 'sources': [
'source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data) serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -140,7 +146,8 @@ class PKForeignKeyTests(TestCase):
expected = [ expected = [
{'id': 1, 'name': u'target-1', 'sources': ['source-2']}, {'id': 1, 'name': u'target-1', 'sources': ['source-2']},
{'id': 2, 'name': u'target-2', 'sources': []}, {'id': 2, 'name': u'target-2', 'sources': []},
{'id': 3, 'name': u'target-3', 'sources': ['source-1', 'source-3']}, {'id': 3, 'name': u'target-3', 'sources': [
'source-1', 'source-3']},
] ]
self.assertEquals(serializer.data, expected) self.assertEquals(serializer.data, expected)
@ -149,7 +156,8 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'target': [u'Value may not be null']}) self.assertEquals(
serializer.errors, {'target': [u'Value may not be null']})
class SlugNullableForeignKeyTests(TestCase): class SlugNullableForeignKeyTests(TestCase):
@ -159,7 +167,8 @@ class SlugNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(
name='source-%d' % idx, target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):

View File

@ -80,14 +80,19 @@ class HTMLView1(APIView):
return Response('text') return Response('text')
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), RendererA, RendererB])),
url(r'^$', MockView.as_view(
renderer_classes=[RendererA, RendererB])),
url(r'^cache$', MockGETView.as_view()), url(r'^cache$', MockGETView.as_view()),
url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[JSONRenderer, JSONPRenderer])), url(r'^jsonp/jsonrenderer$', MockGETView.as_view(renderer_classes=[
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(renderer_classes=[JSONPRenderer])), JSONRenderer, JSONPRenderer])),
url(r'^jsonp/nojsonrenderer$', MockGETView.as_view(
renderer_classes=[JSONPRenderer])),
url(r'^html$', HTMLView.as_view()), url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()), url(r'^html1$', HTMLView1.as_view()),
url(r'^api', include('rest_framework.urls', namespace='rest_framework')) url(r'^api', include(
'rest_framework.urls', namespace='rest_framework'))
) )
@ -282,11 +287,13 @@ class JSONPRendererTests(TestCase):
Test JSONP rendering with callback function name. Test JSONP rendering with callback function name.
""" """
callback_func = 'myjsonpcallback' callback_func = 'myjsonpcallback'
resp = self.client.get('/jsonp/nojsonrenderer?callback=' + callback_func, resp = self.client.get(
'/jsonp/nojsonrenderer?callback=' + callback_func,
HTTP_ACCEPT='application/javascript') HTTP_ACCEPT='application/javascript')
self.assertEquals(resp.status_code, status.HTTP_200_OK) self.assertEquals(resp.status_code, status.HTTP_200_OK)
self.assertEquals(resp['Content-Type'], 'application/javascript') self.assertEquals(resp['Content-Type'], 'application/javascript')
self.assertEquals(resp.content, '%s(%s);' % (callback_func, _flat_repr)) self.assertEquals(
resp.content, '%s(%s);' % (callback_func, _flat_repr))
if yaml: if yaml:
@ -380,7 +387,8 @@ class XMLRendererTestCase(TestCase):
Test XML rendering. Test XML rendering.
""" """
renderer = XMLRenderer() renderer = XMLRenderer()
content = renderer.render({'field': Decimal('111.2')}, 'application/xml') content = renderer.render(
{'field': Decimal('111.2')}, 'application/xml')
self.assertXMLContains(content, '<field>111.2</field>') self.assertXMLContains(content, '<field>111.2</field>')
def test_render_none(self): def test_render_none(self):
@ -405,15 +413,18 @@ class XMLRendererTestCase(TestCase):
Test XML rendering. Test XML rendering.
""" """
renderer = XMLRenderer() renderer = XMLRenderer()
content = StringIO(renderer.render(self._complex_data, 'application/xml')) content = StringIO(
renderer.render(self._complex_data, 'application/xml'))
parser = XMLParser() parser = XMLParser()
complex_data_out = parser.parse(content) complex_data_out = parser.parse(content)
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out)) error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (
repr(self._complex_data), repr(complex_data_out))
self.assertEqual(self._complex_data, complex_data_out, error_msg) self.assertEqual(self._complex_data, complex_data_out, error_msg)
def assertXMLContains(self, xml, string): def assertXMLContains(self, xml, string):
self.assertTrue(xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>')) self.assertTrue(
xml.startswith('<?xml version="1.0" encoding="utf-8"?>\n<root>'))
self.assertTrue(xml.endswith('</root>')) self.assertTrue(xml.endswith('</root>'))
self.assertTrue(string in xml, '%r not in %r' % (string, xml)) self.assertTrue(string in xml, '%r not in %r' % (string, xml))

View File

@ -53,7 +53,8 @@ class TestMethodOverloading(TestCase):
POST requests can be overloaded to another method by setting a POST requests can be overloaded to another method by setting a
reserved form field reserved form field
""" """
request = Request(factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'})) request = Request(
factory.post('/', {api_settings.FORM_METHOD_OVERRIDE: 'DELETE'}))
self.assertEqual(request.method, 'DELETE') self.assertEqual(request.method, 'DELETE')
@ -88,7 +89,8 @@ class TestContentParsing(TestCase):
""" """
content = 'qwerty' content = 'qwerty'
content_type = 'text/plain' content_type = 'text/plain'
request = Request(factory.post('/', content, content_type=content_type)) request = Request(
factory.post('/', content, content_type=content_type))
request.parsers = (PlainTextParser(),) request.parsers = (PlainTextParser(),)
self.assertEqual(request.DATA, content) self.assertEqual(request.DATA, content)
@ -112,7 +114,8 @@ class TestContentParsing(TestCase):
if VERSION >= (1, 5): if VERSION >= (1, 5):
from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart
request = Request(factory.put('/', encode_multipart(BOUNDARY, data), request = Request(
factory.put('/', encode_multipart(BOUNDARY, data),
content_type=MULTIPART_CONTENT)) content_type=MULTIPART_CONTENT))
else: else:
request = Request(factory.put('/', data)) request = Request(factory.put('/', data))
@ -252,7 +255,8 @@ class TestContentParsingWithAuthentication(TestCase):
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(
self.username, self.email, self.password)
def test_user_logged_in_authentication_has_POST_when_not_logged_in(self): def test_user_logged_in_authentication_has_POST_when_not_logged_in(self):
""" """
@ -274,10 +278,12 @@ class TestContentParsingWithAuthentication(TestCase):
# content = {'example': 'example'} # content = {'example': 'example'}
# response = self.client.post('/', content) # response = self.client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed") # self.assertEqual(status.OK, response.status_code, "POST data is
# malformed")
# response = self.csrf_client.post('/', content) # response = self.csrf_client.post('/', content)
# self.assertEqual(status.OK, response.status_code, "POST data is malformed") # self.assertEqual(status.OK, response.status_code, "POST data is
# malformed")
class TestUserSetter(TestCase): class TestUserSetter(TestCase):

View File

@ -64,15 +64,19 @@ class HTMLView1(APIView):
urlpatterns = patterns('', urlpatterns = patterns('',
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), RendererA, RendererB])),
url(r'^$', MockView.as_view(
renderer_classes=[RendererA, RendererB])),
url(r'^html$', HTMLView.as_view()), url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()), url(r'^html1$', HTMLView1.as_view()),
url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) url(r'^restframework', include('rest_framework.urls',
namespace='rest_framework'))
) )
# TODO: Clean tests bellow - remove duplicates with above, better unit testing, ... # TODO: Clean tests bellow - remove duplicates with above, better unit
# testing, ...
class RendererIntegrationTests(TestCase): class RendererIntegrationTests(TestCase):
""" """
End-to-end testing of renderers using an ResponseMixin on a generic view. End-to-end testing of renderers using an ResponseMixin on a generic view.

View File

@ -2,7 +2,8 @@ import datetime
import pickle import pickle
from django.test import TestCase from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from rest_framework.tests.models import (HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel, from rest_framework.tests.models import (
HasPositiveIntegerAsChoice, Album, ActionItem, Anchor, BasicModel,
BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel, BlankFieldModel, BlogPost, Book, CallableDefaultValueModel, DefaultValueModel,
ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo) ManyToManyModel, Person, ReadOnlyManyToManyModel, Photo)
@ -117,7 +118,8 @@ class BasicTests(TestCase):
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected) self.assertEquals(serializer.object, expected)
self.assertFalse(serializer.object is expected) self.assertFalse(serializer.object is expected)
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') self.assertEquals(
serializer.data['sub_comment'], 'And Merry Christmas!')
def test_update(self): def test_update(self):
serializer = CommentSerializer(self.comment, data=self.data) serializer = CommentSerializer(self.comment, data=self.data)
@ -125,14 +127,16 @@ class BasicTests(TestCase):
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected) self.assertEquals(serializer.object, expected)
self.assertTrue(serializer.object is expected) self.assertTrue(serializer.object is expected)
self.assertEquals(serializer.data['sub_comment'], 'And Merry Christmas!') self.assertEquals(
serializer.data['sub_comment'], 'And Merry Christmas!')
def test_partial_update(self): def test_partial_update(self):
msg = 'Merry New Year!' msg = 'Merry New Year!'
partial_data = {'content': msg} partial_data = {'content': msg}
serializer = CommentSerializer(self.comment, data=partial_data) serializer = CommentSerializer(self.comment, data=partial_data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
serializer = CommentSerializer(self.comment, data=partial_data, partial=True) serializer = CommentSerializer(
self.comment, data=partial_data, partial=True)
expected = self.comment expected = self.comment
self.assertEqual(serializer.is_valid(), True) self.assertEqual(serializer.is_valid(), True)
self.assertEquals(serializer.object, expected) self.assertEquals(serializer.object, expected)
@ -160,7 +164,8 @@ class BasicTests(TestCase):
""" """
Attempting to update fields set as read_only should have no effect. Attempting to update fields set as read_only should have no effect.
""" """
serializer = PersonSerializer(self.person, data={'name': 'dwight', 'age': 99}) serializer = PersonSerializer(
self.person, data={'name': 'dwight', 'age': 99})
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
instance = serializer.save() instance = serializer.save()
self.assertEquals(serializer.errors, {}) self.assertEquals(serializer.errors, {})
@ -201,7 +206,8 @@ class ValidationTests(TestCase):
} }
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'email': [u'This field is required.']}) self.assertEquals(
serializer.errors, {'email': [u'This field is required.']})
def test_missing_bool_with_default(self): def test_missing_bool_with_default(self):
"""Make sure that a boolean value with a 'False' value is not """Make sure that a boolean value with a 'False' value is not
@ -221,23 +227,27 @@ class ValidationTests(TestCase):
data = ['i am', 'a', 'list'] data = ['i am', 'a', 'list']
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(
serializer.errors, {'non_field_errors': [u'Invalid data']})
data = 'and i am a string' data = 'and i am a string'
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(
serializer.errors, {'non_field_errors': [u'Invalid data']})
data = 42 data = 42
serializer = CommentSerializer(self.comment, data=data) serializer = CommentSerializer(self.comment, data=data)
self.assertEquals(serializer.is_valid(), False) self.assertEquals(serializer.is_valid(), False)
self.assertEquals(serializer.errors, {'non_field_errors': [u'Invalid data']}) self.assertEquals(
serializer.errors, {'non_field_errors': [u'Invalid data']})
def test_cross_field_validation(self): def test_cross_field_validation(self):
class CommentSerializerWithCrossFieldValidator(CommentSerializer): class CommentSerializerWithCrossFieldValidator(CommentSerializer):
def validate(self, attrs): def validate(self, attrs):
if attrs["email"] not in attrs["content"]: if attrs["email"] not in attrs["content"]:
raise serializers.ValidationError("Email address not in content") raise serializers.ValidationError(
"Email address not in content")
return attrs return attrs
data = { data = {
@ -253,7 +263,8 @@ class ValidationTests(TestCase):
serializer = CommentSerializerWithCrossFieldValidator(data=data) serializer = CommentSerializerWithCrossFieldValidator(data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'non_field_errors': [u'Email address not in content']}) self.assertEquals(serializer.errors, {'non_field_errors': [
u'Email address not in content']})
def test_null_is_true_fields(self): def test_null_is_true_fields(self):
""" """
@ -308,7 +319,8 @@ class CustomValidationTests(TestCase):
serializer = self.CommentSerializerWithFieldValidator(data=data) serializer = self.CommentSerializerWithFieldValidator(data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': [u'Test not in value']}) self.assertEquals(
serializer.errors, {'content': [u'Test not in value']})
def test_missing_data(self): def test_missing_data(self):
""" """
@ -318,9 +330,11 @@ class CustomValidationTests(TestCase):
'email': 'tom@example.com', 'email': 'tom@example.com',
'created': datetime.datetime(2012, 1, 1) 'created': datetime.datetime(2012, 1, 1)
} }
serializer = self.CommentSerializerWithFieldValidator(data=incomplete_data) serializer = self.CommentSerializerWithFieldValidator(
data=incomplete_data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'content': [u'This field is required.']}) self.assertEquals(
serializer.errors, {'content': [u'This field is required.']})
def test_wrong_data(self): def test_wrong_data(self):
""" """
@ -333,7 +347,8 @@ class CustomValidationTests(TestCase):
} }
serializer = self.CommentSerializerWithFieldValidator(data=wrong_data) serializer = self.CommentSerializerWithFieldValidator(data=wrong_data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'email': [u'Enter a valid e-mail address.']}) self.assertEquals(
serializer.errors, {'email': [u'Enter a valid e-mail address.']})
class PositiveIntegerAsChoiceTests(TestCase): class PositiveIntegerAsChoiceTests(TestCase):
@ -353,7 +368,8 @@ class ModelValidationTests(TestCase):
serializer.save() serializer.save()
second_serializer = AlbumsSerializer(data={'title': 'a'}) second_serializer = AlbumsSerializer(data={'title': 'a'})
self.assertFalse(second_serializer.is_valid()) self.assertFalse(second_serializer.is_valid())
self.assertEqual(second_serializer.errors, {'title': [u'Album with this Title already exists.']}) self.assertEqual(second_serializer.errors, {'title': [
u'Album with this Title already exists.']})
def test_foreign_key_with_partial(self): def test_foreign_key_with_partial(self):
""" """
@ -369,12 +385,14 @@ class ModelValidationTests(TestCase):
class Meta: class Meta:
model = Photo model = Photo
photo_serializer = PhotoSerializer(data={'description': 'test', 'album': album.pk}) photo_serializer = PhotoSerializer(
data={'description': 'test', 'album': album.pk})
self.assertTrue(photo_serializer.is_valid()) self.assertTrue(photo_serializer.is_valid())
photo = photo_serializer.save() photo = photo_serializer.save()
# Updating only the album (foreign key) # Updating only the album (foreign key)
photo_serializer = PhotoSerializer(instance=photo, data={'album': album.pk}, partial=True) photo_serializer = PhotoSerializer(
instance=photo, data={'album': album.pk}, partial=True)
self.assertTrue(photo_serializer.is_valid()) self.assertTrue(photo_serializer.is_valid())
self.assertTrue(photo_serializer.save()) self.assertTrue(photo_serializer.save())
@ -391,15 +409,18 @@ class RegexValidationTest(TestCase):
def test_create_failed(self): def test_create_failed(self):
serializer = BookSerializer(data={'isbn': '1234567890'}) serializer = BookSerializer(data={'isbn': '1234567890'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': [
u'isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': '12345678901234'}) serializer = BookSerializer(data={'isbn': '12345678901234'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': [
u'isbn has to be exact 13 numbers']})
serializer = BookSerializer(data={'isbn': 'abcdefghijklm'}) serializer = BookSerializer(data={'isbn': 'abcdefghijklm'})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEquals(serializer.errors, {'isbn': [u'isbn has to be exact 13 numbers']}) self.assertEquals(serializer.errors, {'isbn': [
u'isbn has to be exact 13 numbers']})
def test_create_success(self): def test_create_success(self):
serializer = BookSerializer(data={'isbn': '1234567890123'}) serializer = BookSerializer(data={'isbn': '1234567890123'})
@ -415,7 +436,8 @@ class MetadataTests(TestCase):
'created': serializers.DateTimeField 'created': serializers.DateTimeField
} }
for field_name, field in expected.items(): for field_name, field in expected.items():
self.assertTrue(isinstance(serializer.data.fields[field_name], field)) self.assertTrue(
isinstance(serializer.data.fields[field_name], field))
class ManyToManyTests(TestCase): class ManyToManyTests(TestCase):
@ -605,7 +627,8 @@ class DefaultValueTests(TestCase):
instance = serializer.save() instance = serializer.save()
data = {'extra': 'extra_value'} data = {'extra': 'extra_value'}
serializer = self.serializer_class(instance=instance, data=data, partial=True) serializer = self.serializer_class(
instance=instance, data=data, partial=True)
self.assertEquals(serializer.is_valid(), True) self.assertEquals(serializer.is_valid(), True)
instance = serializer.save() instance = serializer.save()
@ -674,7 +697,8 @@ class ManyRelatedTests(TestCase):
class BlogPostSerializer(serializers.Serializer): class BlogPostSerializer(serializers.Serializer):
title = serializers.CharField() title = serializers.CharField()
first_comment = BlogPostCommentSerializer(source='get_first_comment') first_comment = BlogPostCommentSerializer(
source='get_first_comment')
serializer = BlogPostSerializer(post) serializer = BlogPostSerializer(post)
@ -926,5 +950,7 @@ class NestedSerializerContextTests(TestCase):
album_collection = AlbumCollection() album_collection = AlbumCollection()
album_collection.albums = [album1, album2] album_collection.albums = [album1, album2]
# This will raise RuntimeError if context doesn't get passed correctly to the nested Serializers # This will raise RuntimeError if context doesn't get passed correctly
AlbumCollectionSerializer(album_collection, context={'context_item': 'album context'}).data # to the nested Serializers
AlbumCollectionSerializer(
album_collection, context={'context_item': 'album context'}).data

View File

@ -63,5 +63,6 @@ class SettingsTestCase(TestCase):
class TestModelsTestCase(SettingsTestCase): class TestModelsTestCase(SettingsTestCase):
def setUp(self, *args, **kwargs): def setUp(self, *args, **kwargs):
installed_apps = tuple(settings.INSTALLED_APPS) + ('rest_framework.tests',) installed_apps = tuple(
settings.INSTALLED_APPS) + ('rest_framework.tests',)
self.settings_manager.set(INSTALLED_APPS=installed_apps) self.settings_manager.set(INSTALLED_APPS=installed_apps)

View File

@ -123,7 +123,8 @@ class ThrottlingTests(TestCase):
""" """
Ensure for minute based throttles. Ensure for minute based throttles.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(
MockView_MinuteThrottling,
((0, None), ((0, None),
(0, None), (0, None),
(0, None), (0, None),
@ -135,7 +136,8 @@ class ThrottlingTests(TestCase):
If a client follows the recommended next request rate, If a client follows the recommended next request rate,
the throttling rate should stay constant. the throttling rate should stay constant.
""" """
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, self.ensure_response_header_contains_proper_throttle_field(
MockView_MinuteThrottling,
((0, None), ((0, None),
(20, None), (20, None),
(40, None), (40, None),

View File

@ -32,7 +32,8 @@ class FormatSuffixTests(TestCase):
for test_path in test_paths: for test_path in test_paths:
request = factory.get(test_path.path) request = factory.get(test_path.path)
try: try:
callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) callback, callback_args, callback_kwargs = resolver.resolve(
request.path_info)
except: except:
self.fail("Failed to resolve URL: %s" % request.path_info) self.fail("Failed to resolve URL: %s" % request.path_info)
self.assertEquals(callback_args, test_path.args) self.assertEquals(callback_args, test_path.args)
@ -74,6 +75,7 @@ class FormatSuffixTests(TestCase):
test_paths = [ test_paths = [
URLTestPath('/test/path', (), {'foo': 'bar', }), URLTestPath('/test/path', (), {'foo': 'bar', }),
URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}), URLTestPath('/test/path.api', (), {'foo': 'bar', 'format': 'api'}),
URLTestPath('/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}), URLTestPath(
'/test/path.asdf', (), {'foo': 'bar', 'format': 'asdf'}),
] ]
self._resolve_urlpatterns(urlpatterns, test_paths) self._resolve_urlpatterns(urlpatterns, test_paths)

View File

@ -19,7 +19,8 @@
# view = MockView() # view = MockView()
# content = {'qwerty': 'uiop'} # content = {'qwerty': 'uiop'}
# self.assertEqual(FormResource(view).validate_request(content, None), content) # self.assertEqual(FormResource(view).validate_request(content, None),
# content)
# def test_disabled_form_validator_get_bound_form_returns_none(self): # def test_disabled_form_validator_get_bound_form_returns_none(self):
# """If the view's form attribute is None on then # """If the view's form attribute is None on then
@ -36,7 +37,8 @@
# def test_disabled_model_form_validator_returns_content_unchanged(self): # def test_disabled_model_form_validator_returns_content_unchanged(self):
# """If the view's form is None and does not have a Resource with a model set then # """If the view's form is None and does not have a Resource with a model set then
# ModelFormValidator(view).validate_request(content, None) should just return the content unmodified.""" # ModelFormValidator(view).validate_request(content, None) should just
# return the content unmodified."""
# class DisabledModelFormView(View): # class DisabledModelFormView(View):
# resource = ModelResource # resource = ModelResource
@ -120,14 +122,16 @@
# def validation_failure_raises_response_exception(self, validator): # def validation_failure_raises_response_exception(self, validator):
# """If form validation fails a ResourceException 400 (Bad Request) should be raised.""" # """If form validation fails a ResourceException 400 (Bad Request) should be raised."""
# content = {} # content = {}
# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) # self.assertRaises(ImmediateResponse, validator.validate_request,
# content, None)
# def validation_does_not_allow_extra_fields_by_default(self, validator): # def validation_does_not_allow_extra_fields_by_default(self, validator):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)""" # broken clients more easily (eg submitting content with a misnamed field)"""
# content = {'qwerty': 'uiop', 'extra': 'extra'} # content = {'qwerty': 'uiop', 'extra': 'extra'}
# self.assertRaises(ImmediateResponse, validator.validate_request, content, None) # self.assertRaises(ImmediateResponse, validator.validate_request,
# content, None)
# def validation_allows_extra_fields_if_explicitly_set(self, validator): # def validation_allows_extra_fields_if_explicitly_set(self, validator):
# """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names.""" # """If we include an allowed_extra_fields paramater on _validate, then allow fields with those names."""
@ -147,7 +151,8 @@
# def validation_does_not_require_extra_fields_if_explicitly_set(self, validator): # def validation_does_not_require_extra_fields_if_explicitly_set(self, validator):
# """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names.""" # """If we include an allowed_extra_fields paramater on _validate, then do not fail if we do not have fields with those names."""
# content = {'qwerty': 'uiop'} # content = {'qwerty': 'uiop'}
# self.assertEqual(validator._validate(content, None, allowed_extra_fields=('extra',)), content) # self.assertEqual(validator._validate(content, None,
# allowed_extra_fields=('extra',)), content)
# def validation_failed_due_to_no_content_returns_appropriate_message(self, validator): # def validation_failed_due_to_no_content_returns_appropriate_message(self, validator):
# """If validation fails due to no content, ensure the response contains a single non-field error""" # """If validation fails due to no content, ensure the response contains a single non-field error"""
@ -299,26 +304,30 @@
# def test_property_fields_are_allowed_on_model_forms(self): # def test_property_fields_are_allowed_on_model_forms(self):
# """Validation on ModelForms may include property fields that exist on the Model to be included in the input.""" # """Validation on ModelForms may include property fields that exist on the Model to be included in the input."""
# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'} # content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only'}
# self.assertEqual(self.validator.validate_request(content, None), content) # self.assertEqual(self.validator.validate_request(content, None),
# content)
# def test_property_fields_are_not_required_on_model_forms(self): # def test_property_fields_are_not_required_on_model_forms(self):
# """Validation on ModelForms does not require property fields that exist on the Model to be included in the input.""" # """Validation on ModelForms does not require property fields that exist on the Model to be included in the input."""
# content = {'qwerty': 'example', 'uiop': 'example'} # content = {'qwerty': 'example', 'uiop': 'example'}
# self.assertEqual(self.validator.validate_request(content, None), content) # self.assertEqual(self.validator.validate_request(content, None),
# content)
# def test_extra_fields_not_allowed_on_model_forms(self): # def test_extra_fields_not_allowed_on_model_forms(self):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)""" # broken clients more easily (eg submitting content with a misnamed field)"""
# content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'} # content = {'qwerty': 'example', 'uiop': 'example', 'read_only': 'read only', 'extra': 'extra'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # self.assertRaises(ImmediateResponse, self.validator.validate_request,
# content, None)
# def test_validate_requires_fields_on_model_forms(self): # def test_validate_requires_fields_on_model_forms(self):
# """If some (otherwise valid) content includes fields that are not in the form then validation should fail. # """If some (otherwise valid) content includes fields that are not in the form then validation should fail.
# It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up # It might be okay on normal form submission, but for Web APIs we oughta get strict, as it'll help show up
# broken clients more easily (eg submitting content with a misnamed field)""" # broken clients more easily (eg submitting content with a misnamed field)"""
# content = {'read_only': 'read only'} # content = {'read_only': 'read only'}
# self.assertRaises(ImmediateResponse, self.validator.validate_request, content, None) # self.assertRaises(ImmediateResponse, self.validator.validate_request,
# content, None)
# def test_validate_does_not_require_blankable_fields_on_model_forms(self): # def test_validate_does_not_require_blankable_fields_on_model_forms(self):
# """Test standard ModelForm validation behaviour - fields with blank=True are not required.""" # """Test standard ModelForm validation behaviour - fields with blank=True are not required."""
@ -326,4 +335,5 @@
# self.validator.validate_request(content, None) # self.validator.validate_request(content, None)
# def test_model_form_validator_uses_model_forms(self): # def test_model_form_validator_uses_model_forms(self):
# self.assertTrue(isinstance(self.validator.get_bound_form(), forms.ModelForm)) # self.assertTrue(isinstance(self.validator.get_bound_form(),
# forms.ModelForm))

View File

@ -16,7 +16,8 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
patterns = apply_suffix_patterns(urlpattern.url_patterns, patterns = apply_suffix_patterns(urlpattern.url_patterns,
suffix_pattern, suffix_pattern,
suffix_required) suffix_required)
ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) ret.append(
url(regex, include(patterns, namespace, app_name), kwargs))
else: else:
# Regular URL pattern # Regular URL pattern

View File

@ -19,5 +19,6 @@ template_name = {'template_name': 'rest_framework/login.html'}
urlpatterns = patterns('django.contrib.auth.views', urlpatterns = patterns('django.contrib.auth.views',
url(r'^login/$', 'login', template_name, name='login'), url(r'^login/$', 'login', template_name, name='login'),
url(r'^logout/$', 'logout', template_name, name='logout'), url(r'^logout/$', 'logout',
template_name, name='logout'),
) )

View File

@ -21,14 +21,16 @@ class XML2Dict(object):
node_tree[k] = v node_tree[k] = v
# Save childrens # Save childrens
for child in node.getchildren(): for child in node.getchildren():
tag, tree = self._namespace_split(child.tag, self._parse_node(child)) tag, tree = self._namespace_split(
child.tag, self._parse_node(child))
if tag not in node_tree: # the first time, so store it in dict if tag not in node_tree: # the first time, so store it in dict
node_tree[tag] = tree node_tree[tag] = tree
continue continue
old = node_tree[tag] old = node_tree[tag]
if not isinstance(old, list): if not isinstance(old, list):
node_tree.pop(tag) node_tree.pop(tag)
node_tree[tag] = [old] # multi times, so change old dict to a list node_tree[tag] = [old]
# multi times, so change old dict to a list
node_tree[tag].append(tree) # add the new one node_tree[tag].append(tree) # add the new one
return node_tree return node_tree
@ -52,7 +54,8 @@ class XML2Dict(object):
def fromstring(self, s): def fromstring(self, s):
"""parse a string""" """parse a string"""
t = ET.fromstring(s) t = ET.fromstring(s)
unused_root_tag, root_tree = self._namespace_split(t.tag, self._parse_node(t)) unused_root_tag, root_tree = self._namespace_split(
t.tag, self._parse_node(t))
return root_tree return root_tree

View File

@ -14,12 +14,14 @@ def get_breadcrumbs(url):
except Exception: except Exception:
pass pass
else: else:
# Check if this is a REST framework view, and if so add it to the breadcrumbs # Check if this is a REST framework view, and if so add it to the
# breadcrumbs
if isinstance(getattr(view, 'cls_instance', None), APIView): if isinstance(getattr(view, 'cls_instance', None), APIView):
# Don't list the same view twice in a row. # Don't list the same view twice in a row.
# Probably an optional trailing slash. # Probably an optional trailing slash.
if not seen or seen[-1] != view: if not seen or seen[-1] != view:
breadcrumbs_list.insert(0, (view.cls_instance.get_name(), prefix + url)) breadcrumbs_list.insert(
0, (view.cls_instance.get_name(), prefix + url))
seen.append(view) seen.append(view)
if url == '': if url == '':
@ -27,10 +29,12 @@ def get_breadcrumbs(url):
return breadcrumbs_list return breadcrumbs_list
elif url.endswith('/'): elif url.endswith('/'):
# Drop trailing slash off the end and continue to try to resolve more breadcrumbs # Drop trailing slash off the end and continue to try to resolve
# more breadcrumbs
return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen) return breadcrumbs_recursive(url.rstrip('/'), breadcrumbs_list, prefix, seen)
# Drop trailing non-slash off the end and continue to try to resolve more breadcrumbs # Drop trailing non-slash off the end and continue to try to resolve
# more breadcrumbs
return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen) return breadcrumbs_recursive(url[:url.rfind('/') + 1], breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/') prefix = get_script_prefix().rstrip('/')

View File

@ -36,7 +36,8 @@ def _remove_leading_indent(content):
# unindent the content if needed # unindent the content if needed
if whitespace_counts: if whitespace_counts:
whitespace_pattern = '^' + (' ' * min(whitespace_counts)) whitespace_pattern = '^' + (' ' * min(whitespace_counts))
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content) content = re.sub(
re.compile(whitespace_pattern, re.MULTILINE), '', content)
content = content.strip('\n') content = content.strip('\n')
return content return content
@ -299,7 +300,8 @@ class APIView(View):
self.permission_denied(request) self.permission_denied(request)
self.check_throttles(request) self.check_throttles(request)
# Perform content negotiation and store the accepted info on the request # Perform content negotiation and store the accepted info on the
# request
neg = self.perform_content_negotiation(request) neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg request.accepted_renderer, request.accepted_media_type = neg
@ -383,7 +385,8 @@ class APIView(View):
except Exception as exc: except Exception as exc:
response = self.handle_exception(exc) response = self.handle_exception(exc)
self.response = self.finalize_response(request, response, *args, **kwargs) self.response = self.finalize_response(
request, response, *args, **kwargs)
return self.response return self.response
def options(self, request, *args, **kwargs): def options(self, request, *args, **kwargs):