PEP8 and pyflakes fixes

This commit is contained in:
Adam Nelson 2016-10-21 11:20:10 -04:00
parent 28b801bc1b
commit 556cf1205d
25 changed files with 1467 additions and 913 deletions

View File

@ -8,14 +8,20 @@ to return this information in a more standardized way.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from collections import OrderedDict
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.http import Http404 from django.http import Http404
from django.utils.encoding import force_text from django.utils.encoding import force_text
from rest_framework import exceptions, serializers from rest_framework import exceptions
from rest_framework.fields import (
BooleanField, CharField, ChoiceField, DateField, DateTimeField,
DecimalField, DictField, EmailField, Field, FileField, FloatField,
ImageField, IntegerField, ListField, MultipleChoiceField, NullBooleanField,
OrderedDict, RegexField, SlugField, TimeField, URLField
)
from rest_framework.relations import ManyRelatedField, RelatedField
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.serializers import Serializer
from rest_framework.utils.field_mapping import ClassLookupDict from rest_framework.utils.field_mapping import ClassLookupDict
@ -36,35 +42,37 @@ class SimpleMetadata(BaseMetadata):
for us to base this on. for us to base this on.
""" """
label_lookup = ClassLookupDict({ label_lookup = ClassLookupDict({
serializers.Field: 'field', Field: 'field',
serializers.BooleanField: 'boolean', BooleanField: 'boolean',
serializers.NullBooleanField: 'boolean', NullBooleanField: 'boolean',
serializers.CharField: 'string', CharField: 'string',
serializers.URLField: 'url', URLField: 'url',
serializers.EmailField: 'email', EmailField: 'email',
serializers.RegexField: 'regex', RegexField: 'regex',
serializers.SlugField: 'slug', SlugField: 'slug',
serializers.IntegerField: 'integer', IntegerField: 'integer',
serializers.FloatField: 'float', FloatField: 'float',
serializers.DecimalField: 'decimal', DecimalField: 'decimal',
serializers.DateField: 'date', DateField: 'date',
serializers.DateTimeField: 'datetime', DateTimeField: 'datetime',
serializers.TimeField: 'time', TimeField: 'time',
serializers.ChoiceField: 'choice', ChoiceField: 'choice',
serializers.MultipleChoiceField: 'multiple choice', MultipleChoiceField: 'multiple choice',
serializers.FileField: 'file upload', FileField: 'file upload',
serializers.ImageField: 'image upload', ImageField: 'image upload',
serializers.ListField: 'list', ListField: 'list',
serializers.DictField: 'nested object', DictField: 'nested object',
serializers.Serializer: 'nested object', Serializer: 'nested object',
}) })
def determine_metadata(self, request, view): def determine_metadata(self, request, view):
metadata = OrderedDict() metadata = OrderedDict()
metadata['name'] = view.get_view_name() metadata['name'] = view.get_view_name()
metadata['description'] = view.get_view_description() metadata['description'] = view.get_view_description()
metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes] metadata['renders'] = [renderer.media_type for renderer in
metadata['parses'] = [parser.media_type for parser in view.parser_classes] view.renderer_classes]
metadata['parses'] = [parser.media_type for parser in
view.parser_classes]
if hasattr(view, 'get_serializer'): if hasattr(view, 'get_serializer'):
actions = self.determine_actions(request, view) actions = self.determine_actions(request, view)
if actions: if actions:
@ -90,7 +98,7 @@ class SimpleMetadata(BaseMetadata):
pass pass
else: else:
# If user has appropriate permissions for the view, include # If user has appropriate permissions for the view, include
# appropriate metadata about the fields that should be supplied. # appropriate metadata about the fields that should be supplied
serializer = view.get_serializer() serializer = view.get_serializer()
actions[method] = self.get_serializer_info(serializer) actions[method] = self.get_serializer_info(serializer)
finally: finally:
@ -107,10 +115,9 @@ class SimpleMetadata(BaseMetadata):
# If this is a `ListSerializer` then we want to examine the # If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead. # underlying child serializer instance instead.
serializer = serializer.child serializer = serializer.child
return OrderedDict([ return OrderedDict([(field_name, self.get_field_info(field))
(field_name, self.get_field_info(field)) for field_name, field in
for field_name, field in serializer.fields.items() serializer.fields.items()])
])
def get_field_info(self, field): def get_field_info(self, field):
""" """
@ -138,14 +145,12 @@ class SimpleMetadata(BaseMetadata):
field_info['children'] = self.get_serializer_info(field) field_info['children'] = self.get_serializer_info(field)
if (not field_info.get('read_only') and if (not field_info.get('read_only') and
not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and not isinstance(field, (RelatedField, ManyRelatedField)) and
hasattr(field, 'choices')): hasattr(field, 'choices')):
field_info['choices'] = [ field_info['choices'] = [{'value': choice_value,
{ 'display_name': force_text(
'value': choice_value, choice_name, strings_only=True)}
'display_name': force_text(choice_name, strings_only=True) for choice_value, choice_name in
} field.choices.items()]
for choice_value, choice_name in field.choices.items()
]
return field_info return field_info

View File

@ -7,7 +7,6 @@ from django.http import Http404
from rest_framework.compat import is_authenticated from rest_framework.compat import is_authenticated
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS') SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')

View File

@ -9,24 +9,30 @@ REST framework also provides an HTML renderer that renders the browsable API.
from __future__ import unicode_literals from __future__ import unicode_literals
import json import json
from collections import OrderedDict
from django import forms from django import forms
from django.conf import settings from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.paginator import Page from django.core.paginator import Page
from django.http.multipartparser import parse_header from django.http.multipartparser import parse_header
from django.template import Template, loader from django.template import Template, loader
from django.test.client import encode_multipart from django.test.client import encode_multipart
from django.utils import six
from rest_framework import VERSION, exceptions, serializers, status from rest_framework import VERSION, exceptions, status
from rest_framework.compat import ( from rest_framework.compat import (
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi,
template_render template_render
) )
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.fields import (
BooleanField, ChoiceField, DateField, DateTimeField, EmailField, Field,
FileField, FilePathField, FloatField, HiddenField, IntegerField,
MultipleChoiceField, OrderedDict, TimeField, URLField, six
)
from rest_framework.relations import (
ImproperlyConfigured, ManyRelatedField, RelatedField
)
from rest_framework.request import is_form_media_type, override_method from rest_framework.request import is_form_media_type, override_method
from rest_framework.serializers import ListSerializer, Serializer
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import encoders from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs from rest_framework.utils.breadcrumbs import get_breadcrumbs
@ -48,7 +54,8 @@ class BaseRenderer(object):
render_style = 'text' render_style = 'text'
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
raise NotImplementedError('Renderer class requires .render() to be implemented') raise NotImplementedError(
'Renderer class requires .render() to be implemented')
class JSONRenderer(BaseRenderer): class JSONRenderer(BaseRenderer):
@ -64,7 +71,7 @@ class JSONRenderer(BaseRenderer):
# We don't set a charset because JSON is a binary encoding, # We don't set a charset because JSON is a binary encoding,
# that can be encoded as utf-8, utf-16 or utf-32. # that can be encoded as utf-8, utf-16 or utf-32.
# See: http://www.ietf.org/rfc/rfc4627.txt # See: http://www.ietf.org/rfc/rfc4627.txt
# Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/ # http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/
charset = None charset = None
def get_indent(self, accepted_media_type, renderer_context): def get_indent(self, accepted_media_type, renderer_context):
@ -72,7 +79,8 @@ class JSONRenderer(BaseRenderer):
# If the media type looks like 'application/json; indent=4', # If the media type looks like 'application/json; indent=4',
# then pretty print the result. # then pretty print the result.
# Note that we coerce `indent=0` into `indent=None`. # Note that we coerce `indent=0` into `indent=None`.
base_media_type, params = parse_header(accepted_media_type.encode('ascii')) base_media_type, params = parse_header(
accepted_media_type.encode('ascii'))
try: try:
return zero_as_none(max(min(int(params['indent']), 8), 0)) return zero_as_none(max(min(int(params['indent']), 8), 0))
except (KeyError, ValueError, TypeError): except (KeyError, ValueError, TypeError):
@ -192,7 +200,8 @@ class TemplateHTMLRenderer(BaseRenderer):
elif hasattr(view, 'template_name'): elif hasattr(view, 'template_name'):
return [view.template_name] return [view.template_name]
raise ImproperlyConfigured( raise ImproperlyConfigured(
'Returned a template response with no `template_name` attribute set on either the view or response' 'Returned a template response with no `template_name` attribute '
'set on either the view or response'
) )
def get_exception_template(self, response): def get_exception_template(self, response):
@ -260,88 +269,93 @@ class HTMLFormRenderer(BaseRenderer):
base_template = 'form.html' base_template = 'form.html'
default_style = ClassLookupDict({ default_style = ClassLookupDict({
serializers.Field: { Field: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'text' 'input_type': 'text'
}, },
serializers.EmailField: { EmailField: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'email' 'input_type': 'email'
}, },
serializers.URLField: { URLField: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'url' 'input_type': 'url'
}, },
serializers.IntegerField: { IntegerField: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'number' 'input_type': 'number'
}, },
serializers.FloatField: { FloatField: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'number' 'input_type': 'number'
}, },
serializers.DateTimeField: { DateTimeField: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'datetime-local' 'input_type': 'datetime-local'
}, },
serializers.DateField: { DateField: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'date' 'input_type': 'date'
}, },
serializers.TimeField: { TimeField: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'time' 'input_type': 'time'
}, },
serializers.FileField: { FileField: {
'base_template': 'input.html', 'base_template': 'input.html',
'input_type': 'file' 'input_type': 'file'
}, },
serializers.BooleanField: { BooleanField: {
'base_template': 'checkbox.html' 'base_template': 'checkbox.html'
}, },
serializers.ChoiceField: { ChoiceField: {
'base_template': 'select.html', # Also valid: 'radio.html' 'base_template': 'select.html', # Also valid: 'radio.html'
}, },
serializers.MultipleChoiceField: { MultipleChoiceField: {
'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html' 'base_template': 'select_multiple.html',
# Also valid: 'checkbox_multiple.html'
}, },
serializers.RelatedField: { RelatedField: {
'base_template': 'select.html', # Also valid: 'radio.html' 'base_template': 'select.html', # Also valid: 'radio.html'
}, },
serializers.ManyRelatedField: { ManyRelatedField: {
'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html' 'base_template': 'select_multiple.html',
# Also valid: 'checkbox_multiple.html'
}, },
serializers.Serializer: { Serializer: {
'base_template': 'fieldset.html' 'base_template': 'fieldset.html'
}, },
serializers.ListSerializer: { ListSerializer: {
'base_template': 'list_fieldset.html' 'base_template': 'list_fieldset.html'
}, },
serializers.FilePathField: { FilePathField: {
'base_template': 'select.html', 'base_template': 'select.html',
}, },
}) })
def render_field(self, field, parent_style): def render_field(self, field, parent_style):
if isinstance(field._field, serializers.HiddenField): if isinstance(field._field, HiddenField):
return '' return ''
style = dict(self.default_style[field]) style = dict(self.default_style[field])
style.update(field.style) style.update(field.style)
if 'template_pack' not in style: if 'template_pack' not in style:
style['template_pack'] = parent_style.get('template_pack', self.template_pack) style['template_pack'] = parent_style.get('template_pack',
self.template_pack)
style['renderer'] = self style['renderer'] = self
# Get a clone of the field with text-only value representation. # Get a clone of the field with text-only value representation.
field = field.as_form_field() field = field.as_form_field()
if style.get('input_type') == 'datetime-local' and isinstance(field.value, six.text_type): if style.get('input_type') == 'datetime-local' and isinstance(
field.value, six.text_type):
field.value = field.value.rstrip('Z') field.value = field.value.rstrip('Z')
if 'template' in style: if 'template' in style:
template_name = style['template'] template_name = style['template']
else: else:
template_name = style['template_pack'].strip('/') + '/' + style['base_template'] template_name = style['template_pack'].strip('/') + '/' + style[
'base_template']
template = loader.get_template(template_name) template = loader.get_template(template_name)
context = {'field': field, 'style': style} context = {'field': field, 'style': style}
@ -388,7 +402,8 @@ class BrowsableAPIRenderer(BaseRenderer):
renderers = [renderer for renderer in view.renderer_classes renderers = [renderer for renderer in view.renderer_classes
if not issubclass(renderer, BrowsableAPIRenderer)] if not issubclass(renderer, BrowsableAPIRenderer)]
non_template_renderers = [renderer for renderer in renderers non_template_renderers = [renderer for renderer in renderers
if not hasattr(renderer, 'get_template_names')] if
not hasattr(renderer, 'get_template_names')]
if not renderers: if not renderers:
return None return None
@ -410,7 +425,8 @@ class BrowsableAPIRenderer(BaseRenderer):
render_style = getattr(renderer, 'render_style', 'text') render_style = getattr(renderer, 'render_style', 'text')
assert render_style in ['text', 'binary'], 'Expected .render_style ' \ assert render_style in ['text', 'binary'], 'Expected .render_style ' \
'"text" or "binary", but got "%s"' % render_style '"text" or "binary", but ' \
'got "%s"' % render_style
if render_style == 'binary': if render_style == 'binary':
return '[%d bytes of binary content]' % len(content) return '[%d bytes of binary content]' % len(content)
@ -431,7 +447,8 @@ class BrowsableAPIRenderer(BaseRenderer):
return False # Doesn't have permissions return False # Doesn't have permissions
return True return True
def _get_serializer(self, serializer_class, view_instance, request, *args, **kwargs): def _get_serializer(self, serializer_class, view_instance, request, *args,
**kwargs):
kwargs['context'] = { kwargs['context'] = {
'request': request, 'request': request,
'format': self.format, 'format': self.format,
@ -478,10 +495,10 @@ class BrowsableAPIRenderer(BaseRenderer):
has_serializer = getattr(view, 'get_serializer', None) has_serializer = getattr(view, 'get_serializer', None)
has_serializer_class = getattr(view, 'serializer_class', None) has_serializer_class = getattr(view, 'serializer_class', None)
if ( if ((not has_serializer and not has_serializer_class) or
(not has_serializer and not has_serializer_class) or not any(
not any(is_form_media_type(parser.media_type) for parser in view.parser_classes) is_form_media_type(parser.media_type) for parser in
): view.parser_classes)):
return return
if existing_serializer is not None: if existing_serializer is not None:
@ -492,16 +509,21 @@ class BrowsableAPIRenderer(BaseRenderer):
if has_serializer: if has_serializer:
if method in ('PUT', 'PATCH'): if method in ('PUT', 'PATCH'):
serializer = view.get_serializer(instance=instance, **kwargs) serializer = view.get_serializer(instance=instance,
**kwargs)
else: else:
serializer = view.get_serializer(**kwargs) serializer = view.get_serializer(**kwargs)
else: else:
# at this point we must have a serializer_class # at this point we must have a serializer_class
if method in ('PUT', 'PATCH'): if method in ('PUT', 'PATCH'):
serializer = self._get_serializer(view.serializer_class, view, serializer = self._get_serializer(view.serializer_class,
request, instance=instance, **kwargs) view,
request,
instance=instance,
**kwargs)
else: else:
serializer = self._get_serializer(view.serializer_class, view, serializer = self._get_serializer(view.serializer_class,
view,
request, **kwargs) request, **kwargs)
return self.render_form_for_serializer(serializer) return self.render_form_for_serializer(serializer)
@ -569,7 +591,8 @@ class BrowsableAPIRenderer(BaseRenderer):
label='Media type', label='Media type',
choices=choices, choices=choices,
initial=initial, initial=initial,
widget=forms.Select(attrs={'data-override': 'content-type'}) widget=forms.Select(
attrs={'data-override': 'content-type'})
) )
_content = forms.CharField( _content = forms.CharField(
label='Content', label='Content',
@ -583,7 +606,8 @@ class BrowsableAPIRenderer(BaseRenderer):
return view.get_view_name() return view.get_view_name()
def get_description(self, view, status_code): def get_description(self, view, status_code):
if status_code in (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN): if status_code in (
status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN):
return '' return ''
return view.get_view_description(html=True) return view.get_view_description(html=True)
@ -591,7 +615,8 @@ class BrowsableAPIRenderer(BaseRenderer):
return get_breadcrumbs(request.path, request) return get_breadcrumbs(request.path, request)
def get_filter_form(self, data, view, request): def get_filter_form(self, data, view, request):
if not hasattr(view, 'get_queryset') or not hasattr(view, 'filter_backends'): if not hasattr(view, 'get_queryset') or not hasattr(view,
'filter_backends'):
return return
# Infer if this is a list view or not. # Infer if this is a list view or not.
@ -631,9 +656,11 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer = self.get_default_renderer(view) renderer = self.get_default_renderer(view)
raw_data_post_form = self.get_raw_data_form(data, view, 'POST', request) raw_data_post_form = self.get_raw_data_form(data, view, 'POST',
request)
raw_data_put_form = self.get_raw_data_form(data, view, 'PUT', request) raw_data_put_form = self.get_raw_data_form(data, view, 'PUT', request)
raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH', request) raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH',
request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
response_headers = OrderedDict(sorted(response.items())) response_headers = OrderedDict(sorted(response.items()))
@ -644,19 +671,22 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer_content_type += ' ;%s' % renderer.charset renderer_content_type += ' ;%s' % renderer.charset
response_headers['Content-Type'] = renderer_content_type response_headers['Content-Type'] = renderer_content_type
if getattr(view, 'paginator', None) and view.paginator.display_page_controls: if getattr(view, 'paginator',
None) and view.paginator.display_page_controls:
paginator = view.paginator paginator = view.paginator
else: else:
paginator = None paginator = None
csrf_cookie_name = settings.CSRF_COOKIE_NAME csrf_cookie_name = settings.CSRF_COOKIE_NAME
csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME', 'HTTP_X_CSRFToken') # Fallback for Django 1.8 csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME',
'HTTP_X_CSRFToken') # Fallback Django 1.8
if csrf_header_name.startswith('HTTP_'): if csrf_header_name.startswith('HTTP_'):
csrf_header_name = csrf_header_name[5:] csrf_header_name = csrf_header_name[5:]
csrf_header_name = csrf_header_name.replace('_', '-') csrf_header_name = csrf_header_name.replace('_', '-')
context = { context = {
'content': self.get_content(renderer, data, accepted_media_type, renderer_context), 'content': self.get_content(renderer, data, accepted_media_type,
renderer_context),
'view': view, 'view': view,
'request': request, 'request': request,
'response': response, 'response': response,
@ -667,13 +697,18 @@ class BrowsableAPIRenderer(BaseRenderer):
'paginator': paginator, 'paginator': paginator,
'breadcrumblist': self.get_breadcrumbs(request), 'breadcrumblist': self.get_breadcrumbs(request),
'allowed_methods': view.allowed_methods, 'allowed_methods': view.allowed_methods,
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes], 'available_formats': [renderer_cls.format for renderer_cls in
view.renderer_classes],
'response_headers': response_headers, 'response_headers': response_headers,
'put_form': self.get_rendered_html_form(data, view, 'PUT', request), 'put_form': self.get_rendered_html_form(data, view, 'PUT',
'post_form': self.get_rendered_html_form(data, view, 'POST', request), request),
'delete_form': self.get_rendered_html_form(data, view, 'DELETE', request), 'post_form': self.get_rendered_html_form(data, view, 'POST',
'options_form': self.get_rendered_html_form(data, view, 'OPTIONS', request), request),
'delete_form': self.get_rendered_html_form(data, view, 'DELETE',
request),
'options_form': self.get_rendered_html_form(data, view, 'OPTIONS',
request),
'filter_form': self.get_filter_form(data, view, request), 'filter_form': self.get_filter_form(data, view, request),
@ -699,7 +734,8 @@ class BrowsableAPIRenderer(BaseRenderer):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_context(data, accepted_media_type, renderer_context) context = self.get_context(data, accepted_media_type, renderer_context)
ret = template_render(template, context, request=renderer_context['request']) ret = template_render(template, context,
request=renderer_context['request'])
# Munge DELETE Response code to allow us to return content # Munge DELETE Response code to allow us to return content
# (Do this *after* we've rendered the template so that we include # (Do this *after* we've rendered the template so that we include
@ -726,8 +762,11 @@ class AdminRenderer(BrowsableAPIRenderer):
if response.status_code == status.HTTP_400_BAD_REQUEST: if response.status_code == status.HTTP_400_BAD_REQUEST:
# Errors still need to display the list or detail information. # Errors still need to display the list or detail information.
# The only way we can get at that is to simulate a GET request. # The only way we can get at that is to simulate a GET request.
self.error_form = self.get_rendered_html_form(data, view, request.method, request) self.error_form = self.get_rendered_html_form(data, view,
self.error_title = {'POST': 'Create', 'PUT': 'Edit'}.get(request.method, 'Errors') request.method,
request)
self.error_title = {'POST': 'Create', 'PUT': 'Edit'}.get(
request.method, 'Errors')
with override_method(view, request, 'GET') as request: with override_method(view, request, 'GET') as request:
response = view.get(request, *view.args, **view.kwargs) response = view.get(request, *view.args, **view.kwargs)
@ -735,10 +774,12 @@ class AdminRenderer(BrowsableAPIRenderer):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_context(data, accepted_media_type, renderer_context) context = self.get_context(data, accepted_media_type, renderer_context)
ret = template_render(template, context, request=renderer_context['request']) ret = template_render(template, context,
request=renderer_context['request'])
# Creation and deletion should use redirects in the admin style. # Creation and deletion should use redirects in the admin style.
if (response.status_code == status.HTTP_201_CREATED) and ('Location' in response): if (response.status_code == status.HTTP_201_CREATED) \
and ('Location' in response):
response.status_code = status.HTTP_303_SEE_OTHER response.status_code = status.HTTP_303_SEE_OTHER
response['Location'] = request.build_absolute_uri() response['Location'] = request.build_absolute_uri()
ret = '' ret = ''
@ -818,7 +859,8 @@ class CoreJSONRenderer(BaseRenderer):
format = 'corejson' format = 'corejson'
def __init__(self): def __init__(self):
assert coreapi, 'Using CoreJSONRenderer, but `coreapi` is not installed.' assert coreapi, 'Using CoreJSONRenderer, but `coreapi` is not ' \
'installed.'
def render(self, data, media_type=None, renderer_context=None): def render(self, data, media_type=None, renderer_context=None):
indent = bool(renderer_context.get('indent', 0)) indent = bool(renderer_context.get('indent', 0))

View File

@ -5,35 +5,39 @@ from importlib import import_module
from django.conf import settings from django.conf import settings
from django.contrib.admindocs.views import simplify_regex from django.contrib.admindocs.views import simplify_regex
from django.utils import six
from django.utils.encoding import force_text, smart_text from django.utils.encoding import force_text, smart_text
from rest_framework import exceptions, renderers, serializers from rest_framework import exceptions, renderers
from rest_framework.compat import ( from rest_framework.compat import (
RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse
) )
from rest_framework.fields import (
BooleanField, DecimalField, Field, FileField, FloatField, HiddenField,
IntegerField, MultipleChoiceField, six
)
from rest_framework.relations import ManyRelatedField
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import ListSerializer, Serializer
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import formatting from rest_framework.utils import formatting
from rest_framework.utils.field_mapping import ClassLookupDict from rest_framework.utils.field_mapping import ClassLookupDict
from rest_framework.utils.model_meta import _get_pk from rest_framework.utils.model_meta import _get_pk
from rest_framework.views import APIView from rest_framework.views import APIView
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:') header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
types_lookup = ClassLookupDict({ types_lookup = ClassLookupDict({
serializers.Field: 'string', Field: 'string',
serializers.IntegerField: 'integer', IntegerField: 'integer',
serializers.FloatField: 'number', FloatField: 'number',
serializers.DecimalField: 'number', DecimalField: 'number',
serializers.BooleanField: 'boolean', BooleanField: 'boolean',
serializers.FileField: 'file', FileField: 'file',
serializers.MultipleChoiceField: 'array', MultipleChoiceField: 'array',
serializers.ManyRelatedField: 'array', ManyRelatedField: 'array',
serializers.Serializer: 'object', Serializer: 'object',
serializers.ListSerializer: 'array' ListSerializer: 'array'
}) })
@ -104,6 +108,7 @@ class EndpointInspector(object):
""" """
A class to determine the available API endpoints that a project exposes. A class to determine the available API endpoints that a project exposes.
""" """
def __init__(self, patterns=None, urlconf=None): def __init__(self, patterns=None, urlconf=None):
if patterns is None: if patterns is None:
if urlconf is None: if urlconf is None:
@ -176,10 +181,8 @@ class EndpointInspector(object):
if hasattr(callback, 'actions'): if hasattr(callback, 'actions'):
return [method.upper() for method in callback.actions.keys()] return [method.upper() for method in callback.actions.keys()]
return [ return [method for method in callback.cls().allowed_methods
method for method in if method not in ('OPTIONS', 'HEAD')]
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
]
class SchemaGenerator(object): class SchemaGenerator(object):
@ -193,8 +196,9 @@ class SchemaGenerator(object):
} }
endpoint_inspector_cls = EndpointInspector endpoint_inspector_cls = EndpointInspector
# Map the method names we use for viewset actions onto external schema names. # Map the method names we use for viewset actions onto external schema
# These give us names that are more suitable for the external representation. # names. These give us names that are more suitable for the external
# representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'. # Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None coerce_method_names = None
@ -223,7 +227,8 @@ class SchemaGenerator(object):
Generate a `coreapi.Document` representing the API schema. Generate a `coreapi.Document` representing the API schema.
""" """
if self.endpoints is None: if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) inspector = self.endpoint_inspector_cls(self.patterns,
self.urlconf)
self.endpoints = inspector.get_api_endpoints() self.endpoints = inspector.get_api_endpoints()
links = self.get_links(request) links = self.get_links(request)
@ -358,7 +363,8 @@ class SchemaGenerator(object):
fields += self.get_pagination_fields(path, method, view) fields += self.get_pagination_fields(path, method, view)
fields += self.get_filter_fields(path, method, view) fields += self.get_filter_fields(path, method, view)
if fields and any([field.location in ('form', 'body') for field in fields]): if fields and any(
[field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method, view) encoding = self.get_encoding(path, method, view)
else: else:
encoding = None encoding = None
@ -438,7 +444,8 @@ class SchemaGenerator(object):
fields = [] fields = []
for variable in uritemplate.variables(path): for variable in uritemplate.variables(path):
field = coreapi.Field(name=variable, location='path', required=True) field = coreapi.Field(name=variable, location='path',
required=True)
fields.append(field) fields.append(field)
return fields return fields
@ -456,7 +463,7 @@ class SchemaGenerator(object):
serializer = view.get_serializer() serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer): if isinstance(serializer, ListSerializer):
return [ return [
coreapi.Field( coreapi.Field(
name='data', name='data',
@ -466,16 +473,17 @@ class SchemaGenerator(object):
) )
] ]
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, Serializer):
return [] return []
fields = [] fields = []
for field in serializer.fields.values(): for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField): if field.read_only or isinstance(field, HiddenField):
continue continue
required = field.required and method != 'PATCH' required = field.required and method != 'PATCH'
description = force_text(field.help_text) if field.help_text else '' description = force_text(
field.help_text) if field.help_text else ''
field = coreapi.Field( field = coreapi.Field(
name=field.field_name, name=field.field_name,
location='form', location='form',
@ -517,27 +525,30 @@ class SchemaGenerator(object):
the schema document. the schema document.
/users/ ("users", "list"), ("users", "create") /users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") /users/{pk}/ ("users", "read"), ("users", "update"),
/users/enabled/ ("users", "enabled") # custom viewset list action ("users", "delete")
/users/{pk}/star/ ("users", "star") # custom viewset detail action /users/enabled/ ("users", "enabled") # custom viewset list
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create") /users/{pk}/star/ ("users", "star") # custom viewset detail
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete") /users/{pk}/groups/ ("users", "groups", "list"),
("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"),
("users", "groups", "update"),
("users", "groups", "delete")
""" """
if hasattr(view, 'action'): if hasattr(view, 'action'):
# Viewsets have explicitly named actions. # Viewsets have explicitly named actions.
action = view.action action = view.action
else: else:
# Views have no associated action, so we determine one from the method. # Views have no associated action, so we determine one from the
# method.
if is_list_view(subpath, method, view): if is_list_view(subpath, method, view):
action = 'list' action = 'list'
else: else:
action = self.default_mapping[method.lower()] action = self.default_mapping[method.lower()]
named_path_components = [ named_path_components = [component for component
component for component in subpath.strip('/').split('/')
in subpath.strip('/').split('/') if '{' not in component]
if '{' not in component
]
if is_custom_action(action): if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/" # Custom action, eg "/users/{pk}/activate/", "/users/active/"
@ -562,8 +573,10 @@ def get_schema_view(title=None, url=None, renderer_classes=None):
""" """
generator = SchemaGenerator(title=title, url=url) generator = SchemaGenerator(title=title, url=url)
if renderer_classes is None: if renderer_classes is None:
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: if renderers.BrowsableAPIRenderer in \
rclasses = [renderers.CoreJSONRenderer, renderers.BrowsableAPIRenderer] api_settings.DEFAULT_RENDERER_CLASSES:
rclasses = [renderers.CoreJSONRenderer,
renderers.BrowsableAPIRenderer]
else: else:
rclasses = [renderers.CoreJSONRenderer] rclasses = [renderers.CoreJSONRenderer]
else: else:

View File

@ -12,7 +12,9 @@ response content is handled by parsers and renderers.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import inspect
import traceback import traceback
from copy import deepcopy
from django.db import models from django.db import models
from django.db.models import DurationField as ModelDurationField from django.db.models import DurationField as ModelDurationField
@ -23,6 +25,20 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, set_many, unicode_to_repr from rest_framework.compat import postgres_fields, set_many, unicode_to_repr
from rest_framework.fields import (
BooleanField, CharField, ChoiceField, CreateOnlyDefault, DateField,
DateTimeField, DecimalField, DictField, DjangoValidationError,
DurationField, EmailField, ErrorDetail, Field, FileField, FilePathField,
FloatField, HiddenField, ImageField, IntegerField, IPAddressField,
JSONField, ListField, ModelField, NullBooleanField, OrderedDict,
ReadOnlyField, SkipField, SlugField, TimeField, URLField, UUIDField,
ValidationError, api_settings, empty, get_error_detail, html,
representation, set_value, six, timezone
)
from rest_framework.relations import (
HyperlinkedIdentityField, HyperlinkedRelatedField, ImproperlyConfigured,
PKOnlyObject, PrimaryKeyRelatedField, SlugRelatedField
)
from rest_framework.utils import model_meta from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import ( from rest_framework.utils.field_mapping import (
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
@ -36,16 +52,6 @@ from rest_framework.validators import (
UniqueTogetherValidator UniqueTogetherValidator
) )
# Note: We do the following so that users of the framework can use this style:
#
# example_field = serializers.CharField(...)
#
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
from rest_framework.fields import * # NOQA # isort:skip
from rest_framework.relations import * # NOQA # isort:skip
# We assume that 'validators' are intended for the child serializer, # We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer. # rather than the parent serializer.
LIST_SERIALIZER_KWARGS = ( LIST_SERIALIZER_KWARGS = (
@ -125,12 +131,11 @@ class BaseSerializer(Field):
} }
if allow_empty is not None: if allow_empty is not None:
list_kwargs['allow_empty'] = allow_empty list_kwargs['allow_empty'] = allow_empty
list_kwargs.update({ list_kwargs.update({key: value for key, value in kwargs.items()
key: value for key, value in kwargs.items() if key in LIST_SERIALIZER_KWARGS})
if key in LIST_SERIALIZER_KWARGS
})
meta = getattr(cls, 'Meta', None) meta = getattr(cls, 'Meta', None)
list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer) list_serializer_class = getattr(meta, 'list_serializer_class',
ListSerializer)
return list_serializer_class(*args, **list_kwargs) return list_serializer_class(*args, **list_kwargs)
def to_internal_value(self, data): def to_internal_value(self, data):
@ -164,17 +169,17 @@ class BaseSerializer(Field):
# Guard against incorrect use of `serializer.save(commit=False)` # Guard against incorrect use of `serializer.save(commit=False)`
assert 'commit' not in kwargs, ( assert 'commit' not in kwargs, (
"'commit' is not a valid keyword argument to the 'save()' method. " "'commit' is not a valid keyword argument to the 'save()' method. "
"If you need to access data before committing to the database then " "If you need to access data before committing to the database then"
"inspect 'serializer.validated_data' instead. " " inspect 'serializer.validated_data' instead. "
"You can also pass additional keyword arguments to 'save()' if you " "You can also pass additional keyword arguments to 'save()' if you"
"need to set extra attributes on the saved model instance. " " need to set extra attributes on the saved model instance. "
"For example: 'serializer.save(owner=request.user)'.'" "For example: 'serializer.save(owner=request.user)'.'"
) )
assert not hasattr(self, '_data'), ( assert not hasattr(self, '_data'), (
"You cannot call `.save()` after accessing `serializer.data`." "You cannot call `.save()` after accessing `serializer.data`."
"If you need to access data before committing to the database then " "If you need to access data before committing to the database then"
"inspect 'serializer.validated_data' instead. " " inspect 'serializer.validated_data' instead. "
) )
validated_data = dict( validated_data = dict(
@ -224,7 +229,8 @@ class BaseSerializer(Field):
@property @property
def data(self): def data(self):
if hasattr(self, 'initial_data') and not hasattr(self, '_validated_data'): if hasattr(self, 'initial_data') and not hasattr(self,
'_validated_data'):
msg = ( msg = (
'When a serializer is passed a `data` keyword argument you ' 'When a serializer is passed a `data` keyword argument you '
'must call `.is_valid()` before attempting to access the ' 'must call `.is_valid()` before attempting to access the '
@ -235,9 +241,12 @@ class BaseSerializer(Field):
raise AssertionError(msg) raise AssertionError(msg)
if not hasattr(self, '_data'): if not hasattr(self, '_data'):
if self.instance is not None and not getattr(self, '_errors', None): if self.instance is not None and not getattr(self, '_errors',
None):
self._data = self.to_representation(self.instance) self._data = self.to_representation(self.instance)
elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None): elif hasattr(self, '_validated_data') and not getattr(self,
'_errors',
None):
self._data = self.to_representation(self.validated_data) self._data = self.to_representation(self.validated_data)
else: else:
self._data = self.get_initial() self._data = self.get_initial()
@ -253,7 +262,8 @@ class BaseSerializer(Field):
@property @property
def validated_data(self): def validated_data(self):
if not hasattr(self, '_validated_data'): if not hasattr(self, '_validated_data'):
msg = 'You must call `.is_valid()` before accessing `.validated_data`.' msg = 'You must call `.is_valid()` before accessing ' \
'`.validated_data`.'
raise AssertionError(msg) raise AssertionError(msg)
return self._validated_data return self._validated_data
@ -277,9 +287,9 @@ class SerializerMetaclass(type):
if isinstance(obj, Field)] if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1]._creation_counter) fields.sort(key=lambda x: x[1]._creation_counter)
# If this class is subclassing another Serializer, add that Serializer's # If this class is subclassing another Serializer, add that
# fields. Note that we loop over the bases in *reverse*. This is necessary # Serializer's fields. Note that we loop over the bases in *reverse*.
# in order to maintain the correct order of fields. # This is necessary in order to maintain the correct order of fields.
for base in reversed(bases): for base in reversed(bases):
if hasattr(base, '_declared_fields'): if hasattr(base, '_declared_fields'):
fields = list(base._declared_fields.items()) + fields fields = list(base._declared_fields.items()) + fields
@ -302,10 +312,8 @@ def as_serializer_error(exc):
if isinstance(detail, dict): if isinstance(detail, dict):
# If errors may be a dict we use the standard {key: list of values}. # If errors may be a dict we use the standard {key: list of values}.
# Here we ensure that all the values are *lists* of errors. # Here we ensure that all the values are *lists* of errors.
return { return {key: value if isinstance(value, (list, dict)) else [value]
key: value if isinstance(value, (list, dict)) else [value] for key, value in detail.items()}
for key, value in detail.items()
}
elif isinstance(detail, list): elif isinstance(detail, list):
# Errors raised as a list are non-field errors. # Errors raised as a list are non-field errors.
return { return {
@ -320,7 +328,8 @@ def as_serializer_error(exc):
@six.add_metaclass(SerializerMetaclass) @six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer): class Serializer(BaseSerializer):
default_error_messages = { default_error_messages = {
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') 'invalid': _(
'Invalid data. Expected a dictionary, but got {datatype}.')
} }
@property @property
@ -339,17 +348,13 @@ class Serializer(BaseSerializer):
@cached_property @cached_property
def _writable_fields(self): def _writable_fields(self):
return [ return [field for field in self.fields.values()
field for field in self.fields.values() if (not field.read_only) or (field.default is not empty)]
if (not field.read_only) or (field.default is not empty)
]
@cached_property @cached_property
def _readable_fields(self): def _readable_fields(self):
return [ return [field for field in self.fields.values()
field for field in self.fields.values() if not field.write_only]
if not field.write_only
]
def get_fields(self): def get_fields(self):
""" """
@ -358,7 +363,7 @@ class Serializer(BaseSerializer):
# Every new serializer is created with a clone of the field instances. # Every new serializer is created with a clone of the field instances.
# This allows users to dynamically modify the fields on a serializer # This allows users to dynamically modify the fields on a serializer
# instance without affecting every other serializer class. # instance without affecting every other serializer class.
return copy.deepcopy(self._declared_fields) return deepcopy(self._declared_fields)
def get_validators(self): def get_validators(self):
""" """
@ -371,24 +376,22 @@ class Serializer(BaseSerializer):
def get_initial(self): def get_initial(self):
if hasattr(self, 'initial_data'): if hasattr(self, 'initial_data'):
return OrderedDict([ return OrderedDict([(field_name,
(field_name, field.get_value(self.initial_data)) field.get_value(self.initial_data))
for field_name, field in self.fields.items() for field_name, field in self.fields.items()
if (field.get_value(self.initial_data) is not empty) and if (field.get_value(self.initial_data)
not field.read_only is not empty) and not field.read_only])
])
return OrderedDict([ return OrderedDict([(field.field_name, field.get_initial())
(field.field_name, field.get_initial()) for field in self.fields.values()
for field in self.fields.values() if not field.read_only])
if not field.read_only
])
def get_value(self, dictionary): def get_value(self, dictionary):
# We override the default field access in order to support # We override the default field access in order to support
# nested HTML forms. # nested HTML forms.
if html.is_html_input(dictionary): if html.is_html_input(dictionary):
return html.parse_html_dict(dictionary, prefix=self.field_name) or empty return html.parse_html_dict(dictionary,
prefix=self.field_name) or empty
return dictionary.get(self.field_name, empty) return dictionary.get(self.field_name, empty)
def run_validation(self, data=empty): def run_validation(self, data=empty):
@ -405,7 +408,8 @@ class Serializer(BaseSerializer):
try: try:
self.run_validators(value) self.run_validators(value)
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the ' \
'validated data'
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=as_serializer_error(exc)) raise ValidationError(detail=as_serializer_error(exc))
@ -428,7 +432,8 @@ class Serializer(BaseSerializer):
fields = self._writable_fields fields = self._writable_fields
for field in fields: for field in fields:
validate_method = getattr(self, 'validate_' + field.field_name, None) validate_method = getattr(self, 'validate_' + field.field_name,
None)
primitive_value = field.get_value(data) primitive_value = field.get_value(data)
try: try:
validated_value = field.run_validation(primitive_value) validated_value = field.run_validation(primitive_value)
@ -466,7 +471,8 @@ class Serializer(BaseSerializer):
# #
# For related fields with `use_pk_only_optimization` we need to # For related fields with `use_pk_only_optimization` we need to
# resolve the pk value. # resolve the pk value.
check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute check_for_none = attribute.pk if \
isinstance(attribute, PKOnlyObject) else attribute
if check_for_none is None: if check_for_none is None:
ret[field.field_name] = None ret[field.field_name] = None
else: else:
@ -523,15 +529,17 @@ class ListSerializer(BaseSerializer):
many = True many = True
default_error_messages = { default_error_messages = {
'not_a_list': _('Expected a list of items but got type "{input_type}".'), 'not_a_list': _(
'Expected a list of items but got type "{input_type}".'),
'empty': _('This list may not be empty.') 'empty': _('This list may not be empty.')
} }
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child)) self.child = kwargs.pop('child', deepcopy(self.child))
self.allow_empty = kwargs.pop('allow_empty', True) self.allow_empty = kwargs.pop('allow_empty', True)
assert self.child is not None, '`child` is a required argument.' assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.' assert not inspect.isclass(
self.child), '`child` has not been instantiated.'
super(ListSerializer, self).__init__(*args, **kwargs) super(ListSerializer, self).__init__(*args, **kwargs)
self.child.bind(field_name='', parent=self) self.child.bind(field_name='', parent=self)
@ -564,7 +572,8 @@ class ListSerializer(BaseSerializer):
try: try:
self.run_validators(value) self.run_validators(value)
value = self.validate(value) value = self.validate(value)
assert value is not None, '.validate() should return the validated data' assert value is not None, '.validate() should return the ' \
'validated data'
except (ValidationError, DjangoValidationError) as exc: except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=as_serializer_error(exc)) raise ValidationError(detail=as_serializer_error(exc))
@ -616,9 +625,7 @@ class ListSerializer(BaseSerializer):
# so, first get a queryset from the Manager if needed # so, first get a queryset from the Manager if needed
iterable = data.all() if isinstance(data, models.Manager) else data iterable = data.all() if isinstance(data, models.Manager) else data
return [ return [self.child.to_representation(item) for item in iterable]
self.child.to_representation(item) for item in iterable
]
def validate(self, attrs): def validate(self, attrs):
return attrs return attrs
@ -633,9 +640,7 @@ class ListSerializer(BaseSerializer):
) )
def create(self, validated_data): def create(self, validated_data):
return [ return [self.child.create(attrs) for attrs in validated_data]
self.child.create(attrs) for attrs in validated_data
]
def save(self, **kwargs): def save(self, **kwargs):
""" """
@ -644,17 +649,15 @@ class ListSerializer(BaseSerializer):
# Guard against incorrect use of `serializer.save(commit=False)` # Guard against incorrect use of `serializer.save(commit=False)`
assert 'commit' not in kwargs, ( assert 'commit' not in kwargs, (
"'commit' is not a valid keyword argument to the 'save()' method. " "'commit' is not a valid keyword argument to the 'save()' method. "
"If you need to access data before committing to the database then " "If you need to access data before committing to the database then"
"inspect 'serializer.validated_data' instead. " " inspect 'serializer.validated_data' instead. "
"You can also pass additional keyword arguments to 'save()' if you " "You can also pass additional keyword arguments to 'save()' if you"
"need to set extra attributes on the saved model instance. " " need to set extra attributes on the saved model instance. "
"For example: 'serializer.save(owner=request.user)'.'" "For example: 'serializer.save(owner=request.user)'.'"
) )
validated_data = [ validated_data = [dict(list(attrs.items()) + list(kwargs.items()))
dict(list(attrs.items()) + list(kwargs.items())) for attrs in self.validated_data]
for attrs in self.validated_data
]
if self.instance is not None: if self.instance is not None:
self.instance = self.update(self.instance, validated_data) self.instance = self.update(self.instance, validated_data)
@ -769,8 +772,8 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data):
isinstance(validated_data[key], (list, dict)) isinstance(validated_data[key], (list, dict))
for key, field in serializer.fields.items() for key, field in serializer.fields.items()
), ( ), (
'The `.{method_name}()` method does not support writable dotted-source ' 'The `.{method_name}()` method does not support writable dotted-source'
'fields by default.\nWrite an explicit `.{method_name}()` method for ' ' fields by default.\nWrite an explicit `.{method_name}()` method for '
'serializer `{module}.{class_name}`, or set `read_only=True` on ' 'serializer `{module}.{class_name}`, or set `read_only=True` on '
'dotted-source serializer fields.'.format( 'dotted-source serializer fields.'.format(
method_name=method_name, method_name=method_name,
@ -944,7 +947,7 @@ class ModelSerializer(Serializer):
'Cannot use ModelSerializer with Abstract Models.' 'Cannot use ModelSerializer with Abstract Models.'
) )
declared_fields = copy.deepcopy(self._declared_fields) declared_fields = deepcopy(self._declared_fields)
model = getattr(self.Meta, 'model') model = getattr(self.Meta, 'model')
depth = getattr(self.Meta, 'depth', 0) depth = getattr(self.Meta, 'depth', 0)
@ -1003,7 +1006,8 @@ class ModelSerializer(Serializer):
fields = getattr(self.Meta, 'fields', None) fields = getattr(self.Meta, 'fields', None)
exclude = getattr(self.Meta, 'exclude', None) exclude = getattr(self.Meta, 'exclude', None)
if fields and fields != ALL_FIELDS and not isinstance(fields, (list, tuple)): if fields and fields != ALL_FIELDS and not isinstance(fields,
(list, tuple)):
raise TypeError( raise TypeError(
'The `fields` option must be a list or tuple or "__all__". ' 'The `fields` option must be a list or tuple or "__all__". '
'Got %s.' % type(fields).__name__ 'Got %s.' % type(fields).__name__
@ -1043,7 +1047,8 @@ class ModelSerializer(Serializer):
# a subset of fields. # a subset of fields.
required_field_names = set(declared_fields) required_field_names = set(declared_fields)
for cls in self.__class__.__bases__: for cls in self.__class__.__bases__:
required_field_names -= set(getattr(cls, '_declared_fields', [])) required_field_names -= set(
getattr(cls, '_declared_fields', []))
for field_name in required_field_names: for field_name in required_field_names:
assert field_name in fields, ( assert field_name in fields, (
@ -1101,7 +1106,8 @@ class ModelSerializer(Serializer):
if not nested_depth: if not nested_depth:
return self.build_relational_field(field_name, relation_info) return self.build_relational_field(field_name, relation_info)
else: else:
return self.build_nested_field(field_name, relation_info, nested_depth) return self.build_nested_field(field_name, relation_info,
nested_depth)
elif hasattr(model_class, field_name): elif hasattr(model_class, field_name):
return self.build_property_field(field_name, model_class) return self.build_property_field(field_name, model_class)
@ -1126,7 +1132,8 @@ class ModelSerializer(Serializer):
field_class = self.serializer_choice_field field_class = self.serializer_choice_field
# Some model fields may introduce kwargs that would not be valid # Some model fields may introduce kwargs that would not be valid
# for the choice field. We need to strip these out. # for the choice field. We need to strip these out.
# Eg. models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES) # Eg. models.DecimalField(max_digits=3, decimal_places=1, \
# choices=DECIMAL_CHOICES)
valid_kwargs = set(( valid_kwargs = set((
'read_only', 'write_only', 'read_only', 'write_only',
'required', 'default', 'initial', 'source', 'required', 'default', 'initial', 'source',
@ -1144,11 +1151,13 @@ class ModelSerializer(Serializer):
# matched to the model field. # matched to the model field.
field_kwargs.pop('model_field', None) field_kwargs.pop('model_field', None)
if not issubclass(field_class, CharField) and not issubclass(field_class, ChoiceField): if not issubclass(field_class, CharField) and not issubclass(
field_class, ChoiceField):
# `allow_blank` is only valid for textual fields. # `allow_blank` is only valid for textual fields.
field_kwargs.pop('allow_blank', None) field_kwargs.pop('allow_blank', None)
if postgres_fields and isinstance(model_field, postgres_fields.ArrayField): if postgres_fields and isinstance(model_field,
postgres_fields.ArrayField):
# Populate the `child` argument on `ListField` instances generated # Populate the `child` argument on `ListField` instances generated
# for the PostgrSQL specfic `ArrayField`. # for the PostgrSQL specfic `ArrayField`.
child_model_field = model_field.base_field child_model_field = model_field.base_field
@ -1167,7 +1176,9 @@ class ModelSerializer(Serializer):
field_kwargs = get_relation_kwargs(field_name, relation_info) field_kwargs = get_relation_kwargs(field_name, relation_info)
to_field = field_kwargs.pop('to_field', None) to_field = field_kwargs.pop('to_field', None)
if to_field and not relation_info.reverse and not relation_info.related_model._meta.get_field(to_field).primary_key: if to_field and not relation_info.reverse and not \
relation_info.related_model._meta.get_field(
to_field).primary_key:
field_kwargs['slug_field'] = to_field field_kwargs['slug_field'] = to_field
field_class = self.serializer_related_to_field field_class = self.serializer_related_to_field
@ -1181,6 +1192,7 @@ class ModelSerializer(Serializer):
""" """
Create nested fields for forward and reverse relationships. Create nested fields for forward and reverse relationships.
""" """
class NestedSerializer(ModelSerializer): class NestedSerializer(ModelSerializer):
class Meta: class Meta:
model = relation_info.related_model model = relation_info.related_model
@ -1236,7 +1248,8 @@ class ModelSerializer(Serializer):
kwargs.pop('required') kwargs.pop('required')
if extra_kwargs.get('read_only', kwargs.get('read_only', False)): if extra_kwargs.get('read_only', kwargs.get('read_only', False)):
extra_kwargs.pop('required', None) # Read only fields should always omit the 'required' argument. # Read only fields should always omit the 'required' argument.
extra_kwargs.pop('required', None)
kwargs.update(extra_kwargs) kwargs.update(extra_kwargs)
@ -1249,7 +1262,7 @@ class ModelSerializer(Serializer):
Return a dictionary mapping field names to a dictionary of Return a dictionary mapping field names to a dictionary of
additional keyword arguments. additional keyword arguments.
""" """
extra_kwargs = copy.deepcopy(getattr(self.Meta, 'extra_kwargs', {})) extra_kwargs = deepcopy(getattr(self.Meta, 'extra_kwargs', {}))
read_only_fields = getattr(self.Meta, 'read_only_fields', None) read_only_fields = getattr(self.Meta, 'read_only_fields', None)
if read_only_fields is not None: if read_only_fields is not None:
@ -1265,7 +1278,8 @@ class ModelSerializer(Serializer):
return extra_kwargs return extra_kwargs
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs): def get_uniqueness_extra_kwargs(self, field_names, declared_fields,
extra_kwargs):
""" """
Return any additional field options that need to be included as a Return any additional field options that need to be included as a
result of uniqueness constraints on the model. This is returned as result of uniqueness constraints on the model. This is returned as
@ -1288,7 +1302,8 @@ class ModelSerializer(Serializer):
for model_field in model_fields.values(): for model_field in model_fields.values():
# Include each of the `unique_for_*` field names. # Include each of the `unique_for_*` field names.
unique_constraint_names |= {model_field.unique_for_date, model_field.unique_for_month, unique_constraint_names |= {model_field.unique_for_date,
model_field.unique_for_month,
model_field.unique_for_year} model_field.unique_for_year}
unique_constraint_names -= {None} unique_constraint_names -= {None}
@ -1302,13 +1317,15 @@ class ModelSerializer(Serializer):
# Now we have all the field names that have uniqueness constraints # Now we have all the field names that have uniqueness constraints
# applied, we can add the extra 'required=...' or 'default=...' # applied, we can add the extra 'required=...' or 'default=...'
# arguments that are appropriate to these fields, or add a `HiddenField` for it. # arguments that are appropriate to these fields, or add a
# `HiddenField` for it.
hidden_fields = {} hidden_fields = {}
uniqueness_extra_kwargs = {} uniqueness_extra_kwargs = {}
for unique_constraint_name in unique_constraint_names: for unique_constraint_name in unique_constraint_names:
# Get the model field that is referred too. # Get the model field that is referred too.
unique_constraint_field = model._meta.get_field(unique_constraint_name) unique_constraint_field = model._meta.get_field(
unique_constraint_name)
if getattr(unique_constraint_field, 'auto_now_add', None): if getattr(unique_constraint_field, 'auto_now_add', None):
default = CreateOnlyDefault(timezone.now) default = CreateOnlyDefault(timezone.now)
@ -1322,14 +1339,17 @@ class ModelSerializer(Serializer):
if unique_constraint_name in model_fields: if unique_constraint_name in model_fields:
# The corresponding field is present in the serializer # The corresponding field is present in the serializer
if default is empty: if default is empty:
uniqueness_extra_kwargs[unique_constraint_name] = {'required': True} uniqueness_extra_kwargs[unique_constraint_name] = {
'required': True}
else: else:
uniqueness_extra_kwargs[unique_constraint_name] = {'default': default} uniqueness_extra_kwargs[unique_constraint_name] = {
'default': default}
elif default is not empty: elif default is not empty:
# The corresponding field is not present in the # The corresponding field is not present in the
# serializer. We have a default to use for it, so # serializer. We have a default to use for it, so
# add in a hidden field that populates it. # add in a hidden field that populates it.
hidden_fields[unique_constraint_name] = HiddenField(default=default) hidden_fields[unique_constraint_name] = HiddenField(
default=default)
# Update `extra_kwargs` with any new options. # Update `extra_kwargs` with any new options.
for key, value in uniqueness_extra_kwargs.items(): for key, value in uniqueness_extra_kwargs.items():
@ -1393,7 +1413,8 @@ class ModelSerializer(Serializer):
def get_unique_together_validators(self): def get_unique_together_validators(self):
""" """
Determine a default set of validators for any unique_together constraints. Determine a default set of validators for any unique_together
constraints.
""" """
model_class_inheritance_tree = ( model_class_inheritance_tree = (
[self.Meta.model] + [self.Meta.model] +
@ -1404,10 +1425,8 @@ class ModelSerializer(Serializer):
# which may map onto a model field. Any dotted field name lookups # which may map onto a model field. Any dotted field name lookups
# cannot map to a field, and must be a traversal, so we're not # cannot map to a field, and must be a traversal, so we're not
# including those. # including those.
field_names = { field_names = {field.source for field in self._writable_fields
field.source for field in self._writable_fields if (field.source != '*') and ('.' not in field.source)}
if (field.source != '*') and ('.' not in field.source)
}
# Note that we make sure to check `unique_together` both on the # Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes. # base model class, but also on any parent classes.
@ -1469,14 +1488,17 @@ if hasattr(models, 'UUIDField'):
# IPAddressField is deprecated in Django # IPAddressField is deprecated in Django
if hasattr(models, 'IPAddressField'): if hasattr(models, 'IPAddressField'):
ModelSerializer.serializer_field_mapping[models.IPAddressField] = IPAddressField ModelSerializer.serializer_field_mapping[
models.IPAddressField] = IPAddressField
if postgres_fields: if postgres_fields:
class CharMappingField(DictField): class CharMappingField(DictField):
child = CharField(allow_blank=True) child = CharField(allow_blank=True)
ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = \
ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField CharMappingField
ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = \
ListField
class HyperlinkedModelSerializer(ModelSerializer): class HyperlinkedModelSerializer(ModelSerializer):
@ -1505,6 +1527,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
""" """
Create nested fields for forward and reverse relationships. Create nested fields for forward and reverse relationships.
""" """
class NestedSerializer(HyperlinkedModelSerializer): class NestedSerializer(HyperlinkedModelSerializer):
class Meta: class Meta:
model = relation_info.related_model model = relation_info.related_model

View File

@ -33,10 +33,8 @@ class TestManyPostView(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [{'id': obj.id, 'text': obj.text}
{'id': obj.id, 'text': obj.text} for obj in self.objects.all()]
for obj in self.objects.all()
]
self.view = ManyPostView.as_view() self.view = ManyPostView.as_view()
def test_post_many_post_view(self): def test_post_many_post_view(self):
@ -44,7 +42,8 @@ class TestManyPostView(TestCase):
POST request to a view that returns a list of objects should POST request to a view that returns a list of objects should
still successfully return the browsable API with a rendered form. still successfully return the browsable API with a rendered form.
Regression test for https://github.com/tomchristie/django-rest-framework/pull/3164 Regression test for
https://github.com/tomchristie/django-rest-framework/pull/3164
""" """
data = {} data = {}
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')

View File

@ -45,6 +45,7 @@ class NonAtomicAPIExceptionView(APIView):
BasicModel.objects.all() BasicModel.objects.all()
raise Http404 raise Http404
urlpatterns = ( urlpatterns = (
url(r'^$', NonAtomicAPIExceptionView.as_view()), url(r'^$', NonAtomicAPIExceptionView.as_view()),
) )
@ -89,7 +90,8 @@ class DBTransactionErrorTests(TestCase):
Transaction is eventually managed by outer-most transaction atomic Transaction is eventually managed by outer-most transaction atomic
block. DRF do not try to interfere here. block. DRF do not try to interfere here.
We let django deal with the transaction when it will catch the Exception. We let django deal with the transaction when it will catch the
Exception.
""" """
request = factory.post('/') request = factory.post('/')
with self.assertNumQueries(3): with self.assertNumQueries(3):

File diff suppressed because it is too large Load Diff

View File

@ -85,10 +85,8 @@ class TestRootView(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [{'id': obj.id, 'text': obj.text}
{'id': obj.id, 'text': obj.text} for obj in self.objects.all()]
for obj in self.objects.all()
]
self.view = RootView.as_view() self.view = RootView.as_view()
def test_get_root_view(self): def test_get_root_view(self):
@ -122,8 +120,10 @@ class TestRootView(TestCase):
request = factory.put('/', data, format='json') request = factory.put('/', data, format='json')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.status_code,
self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'}) status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data,
{"detail": 'Method "PUT" not allowed.'})
def test_delete_root_view(self): def test_delete_root_view(self):
""" """
@ -132,8 +132,10 @@ class TestRootView(TestCase):
request = factory.delete('/') request = factory.delete('/')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.status_code,
self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'}) status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data,
{"detail": 'Method "DELETE" not allowed.'})
def test_post_cannot_set_id(self): def test_post_cannot_set_id(self):
""" """
@ -156,7 +158,8 @@ class TestRootView(TestCase):
request = factory.post('/', data, HTTP_ACCEPT='text/html') request = factory.post('/', data, HTTP_ACCEPT='text/html')
response = self.view(request).render() response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>' expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
self.assertIn(expected_error, response.rendered_content.decode('utf-8')) self.assertIn(expected_error,
response.rendered_content.decode('utf-8'))
EXPECTED_QUERIES_FOR_PUT = 2 EXPECTED_QUERIES_FOR_PUT = 2
@ -171,10 +174,8 @@ class TestInstanceView(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out') self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [ self.data = [{'id': obj.id, 'text': obj.text}
{'id': obj.id, 'text': obj.text} for obj in self.objects.all()]
for obj in self.objects.all()
]
self.view = InstanceView.as_view() self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view() self.slug_based_view = SlugBasedInstanceView.as_view()
@ -196,8 +197,10 @@ class TestInstanceView(TestCase):
request = factory.post('/', data, format='json') request = factory.post('/', data, format='json')
with self.assertNumQueries(0): with self.assertNumQueries(0):
response = self.view(request).render() response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) self.assertEqual(response.status_code,
self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'}) status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data,
{"detail": 'Method "POST" not allowed.'})
def test_put_instance_view(self): def test_put_instance_view(self):
""" """
@ -280,7 +283,8 @@ class TestInstanceView(TestCase):
""" """
data = {'text': 'foo'} data = {'text': 'foo'}
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') request = factory.put('/{0}'.format(filtered_out_pk), data,
format='json')
response = self.view(request, pk=filtered_out_pk).render() response = self.view(request, pk=filtered_out_pk).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -303,7 +307,8 @@ class TestInstanceView(TestCase):
request = factory.put('/', data, HTTP_ACCEPT='text/html') request = factory.put('/', data, HTTP_ACCEPT='text/html')
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>' expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
self.assertIn(expected_error, response.rendered_content.decode('utf-8')) self.assertIn(expected_error,
response.rendered_content.decode('utf-8'))
class TestFKInstanceView(TestCase): class TestFKInstanceView(TestCase):
@ -318,10 +323,8 @@ class TestFKInstanceView(TestCase):
ForeignKeySource(name='source_' + item, target=t).save() ForeignKeySource(name='source_' + item, target=t).save()
self.objects = ForeignKeySource.objects self.objects = ForeignKeySource.objects
self.data = [ self.data = [{'id': obj.id, 'name': obj.name}
{'id': obj.id, 'name': obj.name} for obj in self.objects.all()]
for obj in self.objects.all()
]
self.view = FKInstanceView.as_view() self.view = FKInstanceView.as_view()
@ -339,10 +342,8 @@ class TestOverriddenGetObject(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [{'id': obj.id, 'text': obj.text}
{'id': obj.id, 'text': obj.text} for obj in self.objects.all()]
for obj in self.objects.all()
]
class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView): class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
""" """
@ -477,10 +478,8 @@ class TestFilterBackendAppliedToViews(TestCase):
for item in items: for item in items:
BasicModel(text=item).save() BasicModel(text=item).save()
self.objects = BasicModel.objects self.objects = BasicModel.objects
self.data = [ self.data = [{'id': obj.id, 'text': obj.text}
{'id': obj.id, 'text': obj.text} for obj in self.objects.all()]
for obj in self.objects.all()
]
def test_get_root_view_filters_by_name_with_filter_backend(self): def test_get_root_view_filters_by_name_with_filter_backend(self):
""" """
@ -493,7 +492,8 @@ class TestFilterBackendAppliedToViews(TestCase):
self.assertEqual(len(response.data), 1) self.assertEqual(len(response.data), 1)
self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}]) self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self): def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(
self):
""" """
GET requests to ListCreateAPIView should return empty list when all models are filtered out. GET requests to ListCreateAPIView should return empty list when all models are filtered out.
""" """
@ -507,17 +507,20 @@ class TestFilterBackendAppliedToViews(TestCase):
""" """
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out. GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
""" """
instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,)) instance_view = InstanceView.as_view(
filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1') request = factory.get('/1')
response = instance_view(request, pk=1).render() response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {'detail': 'Not found.'}) self.assertEqual(response.data, {'detail': 'Not found.'})
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self): def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(
self):
""" """
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
""" """
instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,)) instance_view = InstanceView.as_view(
filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1') request = factory.get('/1')
response = instance_view(request, pk=1).render() response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)

View File

@ -4,11 +4,14 @@ from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from rest_framework import ( from rest_framework import exceptions, metadata, status, versioning, views
exceptions, metadata, serializers, status, versioning, views from rest_framework.fields import (
CharField, ChoiceField, IntegerField, ListField, NullBooleanField
) )
from rest_framework.relations import PrimaryKeyRelatedField, RelatedField
from rest_framework.renderers import BrowsableAPIRenderer from rest_framework.renderers import BrowsableAPIRenderer
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.serializers import ModelSerializer, Serializer
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from .models import BasicModel from .models import BasicModel
@ -21,6 +24,7 @@ class TestMetadata:
""" """
OPTIONS requests to views should return a valid 200 response. OPTIONS requests to views should return a valid 200 response.
""" """
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
pass pass
@ -46,8 +50,9 @@ class TestMetadata:
def test_none_metadata(self): def test_none_metadata(self):
""" """
OPTIONS requests to views where `metadata_class = None` should raise OPTIONS requests to views where `metadata_class = None` should raise
a MethodNotAllowed exception, which will result in an HTTP 405 response. a MethodNotAllowed exception, which will result in an HTTP 405 response
""" """
class ExampleView(views.APIView): class ExampleView(views.APIView):
metadata_class = None metadata_class = None
@ -61,27 +66,29 @@ class TestMetadata:
On generic views OPTIONS should return an 'actions' key with metadata On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests. on the fields that may be supplied to PUT and POST requests.
""" """
class NestedField(serializers.Serializer):
a = serializers.IntegerField()
b = serializers.IntegerField()
class ExampleSerializer(serializers.Serializer): class NestedField(Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue']) a = IntegerField()
integer_field = serializers.IntegerField( b = IntegerField()
class ExampleSerializer(Serializer):
choice_field = ChoiceField(['red', 'green', 'blue'])
integer_field = IntegerField(
min_value=1, max_value=1000 min_value=1, max_value=1000
) )
char_field = serializers.CharField( char_field = CharField(
required=False, min_length=3, max_length=40 required=False, min_length=3, max_length=40
) )
list_field = serializers.ListField( list_field = ListField(
child=serializers.ListField( child=ListField(
child=serializers.IntegerField() child=IntegerField()
) )
) )
nested_field = NestedField() nested_field = NestedField()
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
@ -179,13 +186,15 @@ class TestMetadata:
If a user does not have global permissions on an action, then any If a user does not have global permissions on an action, then any
metadata associated with it should not be included in OPTION responses. metadata associated with it should not be included in OPTION responses.
""" """
class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue']) class ExampleSerializer(Serializer):
integer_field = serializers.IntegerField(max_value=10) choice_field = ChoiceField(['red', 'green', 'blue'])
char_field = serializers.CharField(required=False) integer_field = IntegerField(max_value=10)
char_field = CharField(required=False)
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
@ -209,13 +218,15 @@ class TestMetadata:
If a user does not have object permissions on an action, then any If a user does not have object permissions on an action, then any
metadata associated with it should not be included in OPTION responses. metadata associated with it should not be included in OPTION responses.
""" """
class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue']) class ExampleSerializer(Serializer):
integer_field = serializers.IntegerField(max_value=10) choice_field = ChoiceField(['red', 'green', 'blue'])
char_field = serializers.CharField(required=False) integer_field = IntegerField(max_value=10)
char_field = CharField(required=False)
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass
@ -243,7 +254,7 @@ class TestMetadata:
def get_serializer(self): def get_serializer(self):
assert hasattr(self.request, 'version') assert hasattr(self.request, 'version')
return serializers.Serializer() return Serializer()
view = ExampleView.as_view() view = ExampleView.as_view()
view(request=request) view(request=request)
@ -257,7 +268,7 @@ class TestMetadata:
def get_serializer(self): def get_serializer(self):
assert hasattr(self.request, 'versioning_scheme') assert hasattr(self.request, 'versioning_scheme')
return serializers.Serializer() return Serializer()
scheme = versioning.QueryParameterVersioning scheme = versioning.QueryParameterVersioning
view = ExampleView.as_view(versioning_class=scheme) view = ExampleView.as_view(versioning_class=scheme)
@ -267,7 +278,7 @@ class TestMetadata:
class TestSimpleMetadataFieldInfo(TestCase): class TestSimpleMetadataFieldInfo(TestCase):
def test_null_boolean_field_info_type(self): def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata() options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField()) field_info = options.get_field_info(NullBooleanField())
self.assertEqual(field_info['type'], 'boolean') self.assertEqual(field_info['type'], 'boolean')
def test_related_field_choices(self): def test_related_field_choices(self):
@ -275,7 +286,7 @@ class TestSimpleMetadataFieldInfo(TestCase):
BasicModel.objects.create() BasicModel.objects.create()
with self.assertNumQueries(0): with self.assertNumQueries(0):
field_info = options.get_field_info( field_info = options.get_field_info(
serializers.RelatedField(queryset=BasicModel.objects.all()) RelatedField(queryset=BasicModel.objects.all())
) )
self.assertNotIn('choices', field_info) self.assertNotIn('choices', field_info)
@ -287,16 +298,18 @@ class TestModelSerializerMetadata(TestCase):
on the fields that may be supplied to PUT and POST requests. It should on the fields that may be supplied to PUT and POST requests. It should
not fail when a read_only PrimaryKeyRelatedField is present not fail when a read_only PrimaryKeyRelatedField is present
""" """
class Parent(models.Model): class Parent(models.Model):
integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)]) integer_field = models.IntegerField(
validators=[MinValueValidator(1), MaxValueValidator(1000)])
children = models.ManyToManyField('Child') children = models.ManyToManyField('Child')
name = models.CharField(max_length=100, blank=True, null=True) name = models.CharField(max_length=100, blank=True, null=True)
class Child(models.Model): class Child(models.Model):
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
class ExampleSerializer(serializers.ModelSerializer): class ExampleSerializer(ModelSerializer):
children = serializers.PrimaryKeyRelatedField(read_only=True, many=True) children = PrimaryKeyRelatedField(read_only=True, many=True)
class Meta: class Meta:
model = Parent model = Parent
@ -304,6 +317,7 @@ class TestModelSerializerMetadata(TestCase):
class ExampleView(views.APIView): class ExampleView(views.APIView):
"""Example view.""" """Example view."""
def post(self, request): def post(self, request):
pass pass

View File

@ -17,7 +17,8 @@ class ChildModel(ParentModel):
class AssociatedModel(RESTFrameworkModel): class AssociatedModel(RESTFrameworkModel):
ref = models.OneToOneField(ParentModel, primary_key=True, on_delete=models.CASCADE) ref = models.OneToOneField(ParentModel, primary_key=True,
on_delete=models.CASCADE)
name = models.CharField(max_length=100) name = models.CharField(max_length=100)
@ -36,7 +37,6 @@ class AssociatedModelSerializer(serializers.ModelSerializer):
# Tests # Tests
class InheritedModelSerializationTests(TestCase): class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
Assert that the parent pointer field is not included in the fields Assert that the parent pointer field is not included in the fields

View File

@ -5,9 +5,6 @@ from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from tests.models import RESTFrameworkModel from tests.models import RESTFrameworkModel
# Models
from tests.test_multitable_inheritance import ChildModel from tests.test_multitable_inheritance import ChildModel
@ -26,7 +23,6 @@ class DerivedModelSerializer(serializers.ModelSerializer):
class ChildAssociatedModelSerializer(serializers.ModelSerializer): class ChildAssociatedModelSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = ChildAssociatedModel model = ChildAssociatedModel
fields = ['id', 'child_name'] fields = ['id', 'child_name']
@ -34,7 +30,6 @@ class ChildAssociatedModelSerializer(serializers.ModelSerializer):
# Tests # Tests
class InheritedModelSerializationTests(TestCase): class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self): def test_multitable_inherited_model_fields_as_expected(self):
""" """
Assert that the parent pointer field is not included in the fields Assert that the parent pointer field is not included in the fields

View File

@ -65,27 +65,33 @@ empty_list_view = EmptyListView.as_view()
def basic_auth_header(username, password): def basic_auth_header(username, password):
credentials = ('%s:%s' % (username, password)) credentials = ('%s:%s' % (username, password))
base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) base64_credentials = base64.b64encode(
credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
return 'Basic %s' % base64_credentials return 'Basic %s' % base64_credentials
class ModelPermissionsIntegrationTests(TestCase): class ModelPermissionsIntegrationTests(TestCase):
def setUp(self): def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password') User.objects.create_user('disallowed', 'disallowed@example.com',
user = User.objects.create_user('permitted', 'permitted@example.com', 'password') 'password')
user = User.objects.create_user('permitted', 'permitted@example.com',
'password')
set_many(user, 'user_permissions', [ set_many(user, 'user_permissions', [
Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='add_basicmodel'),
Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='change_basicmodel'),
Permission.objects.get(codename='delete_basicmodel') Permission.objects.get(codename='delete_basicmodel')
]) ])
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') user = User.objects.create_user('updateonly', 'updateonly@example.com',
'password')
set_many(user, 'user_permissions', [ set_many(user, 'user_permissions', [
Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='change_basicmodel'),
]) ])
self.permitted_credentials = basic_auth_header('permitted', 'password') self.permitted_credentials = basic_auth_header('permitted', 'password')
self.disallowed_credentials = basic_auth_header('disallowed', 'password') self.disallowed_credentials = basic_auth_header('disallowed',
self.updateonly_credentials = basic_auth_header('updateonly', 'password') 'password')
self.updateonly_credentials = basic_auth_header('updateonly',
'password')
BasicModel(text='foo').save() BasicModel(text='foo').save()
@ -120,7 +126,8 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_has_delete_permissions(self): def test_has_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.delete('/1',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
@ -137,7 +144,8 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_does_not_have_delete_permissions(self): def test_does_not_have_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) request = factory.delete(
'/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@ -196,7 +204,8 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(list(response.data['actions'].keys()), ['PUT']) self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_empty_view_does_not_assert(self): def test_empty_view_does_not_assert(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials) request = factory.get('/1',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = empty_list_view(request, pk=1) response = empty_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -237,6 +246,7 @@ class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions] permission_classes = [ViewObjectPermissions]
object_permissions_view = ObjectPermissionInstanceView.as_view() object_permissions_view = ObjectPermissionInstanceView.as_view()
@ -246,10 +256,12 @@ class ObjectPermissionListView(generics.ListAPIView):
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions] permission_classes = [ViewObjectPermissions]
object_permissions_list_view = ObjectPermissionListView.as_view() object_permissions_list_view = ObjectPermissionListView.as_view()
class GetQuerysetObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView): class GetQuerysetObjectPermissionInstanceView(generics.
RetrieveUpdateDestroyAPIView):
serializer_class = BasicPermSerializer serializer_class = BasicPermSerializer
authentication_classes = [authentication.BasicAuthentication] authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions] permission_classes = [ViewObjectPermissions]
@ -258,7 +270,8 @@ class GetQuerysetObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIV
return BasicPermModel.objects.all() return BasicPermModel.objects.all()
get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.as_view() get_queryset_object_permissions_view = \
GetQuerysetObjectPermissionInstanceView.as_view()
@unittest.skipUnless(guardian, 'django-guardian not installed') @unittest.skipUnless(guardian, 'django-guardian not installed')
@ -266,16 +279,20 @@ class ObjectPermissionsIntegrationTests(TestCase):
""" """
Integration tests for the object level permissions API. Integration tests for the object level permissions API.
""" """
def setUp(self): def setUp(self):
from guardian.shortcuts import assign_perm from guardian.shortcuts import assign_perm
# create users # create users
create = User.objects.create_user create = User.objects.create_user
users = { users = {
'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'), 'fullaccess': create('fullaccess', 'fullaccess@example.com',
'password'),
'readonly': create('readonly', 'readonly@example.com', 'password'), 'readonly': create('readonly', 'readonly@example.com', 'password'),
'writeonly': create('writeonly', 'writeonly@example.com', 'password'), 'writeonly': create('writeonly', 'writeonly@example.com',
'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'), 'password'),
'deleteonly': create('deleteonly', 'deleteonly@example.com',
'password'),
} }
# give everyone model level permissions, as we are not testing those # give everyone model level permissions, as we are not testing those
@ -310,16 +327,19 @@ class ObjectPermissionsIntegrationTests(TestCase):
self.credentials = {} self.credentials = {}
for user in users.values(): for user in users.values():
self.credentials[user.username] = basic_auth_header(user.username, 'password') self.credentials[user.username] = basic_auth_header(user.username,
'password')
# Delete # Delete
def test_can_delete_permissions(self): def test_can_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly']) request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials[
'deleteonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_cannot_delete_permissions(self): def test_cannot_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials[
'readonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@ -351,12 +371,14 @@ class ObjectPermissionsIntegrationTests(TestCase):
# Read # Read
def test_can_read_permissions(self): def test_can_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/1',
HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_cannot_read_permissions(self): def test_cannot_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly']) request = factory.get('/1',
HTTP_AUTHORIZATION=self.credentials['writeonly'])
response = object_permissions_view(request, pk='1') response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -365,21 +387,26 @@ class ObjectPermissionsIntegrationTests(TestCase):
same as ``test_can_read_permissions`` but with a view same as ``test_can_read_permissions`` but with a view
that rely on ``.get_queryset()`` instead of ``.queryset``. that rely on ``.get_queryset()`` instead of ``.queryset``.
""" """
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/1',
HTTP_AUTHORIZATION=self.credentials['readonly'])
response = get_queryset_object_permissions_view(request, pk='1') response = get_queryset_object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
# Read list # Read list
def test_can_read_list_permissions(self): def test_can_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly']) request = factory.get('/',
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) HTTP_AUTHORIZATION=self.credentials['readonly'])
object_permissions_list_view.cls.filter_backends = (
DjangoObjectPermissionsFilter,)
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data[0].get('id'), 1) self.assertEqual(response.data[0].get('id'), 1)
def test_cannot_read_list_permissions(self): def test_cannot_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly']) request = factory.get('/',
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,) HTTP_AUTHORIZATION=self.credentials['writeonly'])
object_permissions_list_view.cls.filter_backends = (
DjangoObjectPermissionsFilter,)
response = object_permissions_list_view(request) response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertListEqual(response.data, []) self.assertListEqual(response.data, [])
@ -429,6 +456,7 @@ class DeniedObjectView(PermissionInstanceView):
class DeniedObjectViewWithDetail(PermissionInstanceView): class DeniedObjectViewWithDetail(PermissionInstanceView):
permission_classes = (BasicObjectPermWithDetail,) permission_classes = (BasicObjectPermWithDetail,)
denied_view = DeniedView.as_view() denied_view = DeniedView.as_view()
denied_view_with_detail = DeniedViewWithDetail.as_view() denied_view_with_detail = DeniedViewWithDetail.as_view()
@ -441,31 +469,33 @@ denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()
class CustomPermissionsTests(TestCase): class CustomPermissionsTests(TestCase):
def setUp(self): def setUp(self):
BasicModel(text='foo').save() BasicModel(text='foo').save()
User.objects.create_user('username', 'username@example.com', 'password') User.objects.create_user('username', 'username@example.com',
'password')
credentials = basic_auth_header('username', 'password') credentials = basic_auth_header('username', 'password')
self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials) self.request = factory.get('/1', format='json',
HTTP_AUTHORIZATION=credentials)
self.custom_message = 'Custom: You cannot access this resource' self.custom_message = 'Custom: You cannot access this resource'
def test_permission_denied(self): def test_permission_denied(self):
response = denied_view(self.request, pk=1) response = denied_view(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertNotEqual(detail, self.custom_message) self.assertNotEqual(detail, self.custom_message)
def test_permission_denied_with_custom_detail(self): def test_permission_denied_with_custom_detail(self):
response = denied_view_with_detail(self.request, pk=1) response = denied_view_with_detail(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(detail, self.custom_message) self.assertEqual(detail, self.custom_message)
def test_permission_denied_for_object(self): def test_permission_denied_for_object(self):
response = denied_object_view(self.request, pk=1) response = denied_object_view(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertNotEqual(detail, self.custom_message) self.assertNotEqual(detail, self.custom_message)
def test_permission_denied_for_object_with_custom_detail(self): def test_permission_denied_for_object_with_custom_detail(self):
response = denied_object_view_with_detail(self.request, pk=1) response = denied_object_view_with_detail(self.request, pk=1)
detail = response.data.get('detail') detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(detail, self.custom_message) self.assertEqual(detail, self.custom_message)

View File

@ -1,11 +1,14 @@
import uuid import uuid
import pytest import pytest
from django.core.exceptions import ImproperlyConfigured
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from rest_framework import serializers from rest_framework.fields import UUIDField, ValidationError, empty
from rest_framework.fields import empty from rest_framework.relations import (
Hyperlink, HyperlinkedIdentityField, HyperlinkedRelatedField,
ImproperlyConfigured, PrimaryKeyRelatedField, SlugRelatedField,
StringRelatedField
)
from rest_framework.test import APISimpleTestCase from rest_framework.test import APISimpleTestCase
from .utils import ( from .utils import (
@ -16,7 +19,7 @@ from .utils import (
class TestStringRelatedField(APISimpleTestCase): class TestStringRelatedField(APISimpleTestCase):
def setUp(self): def setUp(self):
self.instance = MockObject(pk=1, name='foo') self.instance = MockObject(pk=1, name='foo')
self.field = serializers.StringRelatedField() self.field = StringRelatedField()
def test_string_related_representation(self): def test_string_related_representation(self):
representation = self.field.to_representation(self.instance) representation = self.field.to_representation(self.instance)
@ -31,20 +34,20 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
MockObject(pk=3, name='baz') MockObject(pk=3, name='baz')
]) ])
self.instance = self.queryset.items[2] self.instance = self.queryset.items[2]
self.field = serializers.PrimaryKeyRelatedField(queryset=self.queryset) self.field = PrimaryKeyRelatedField(queryset=self.queryset)
def test_pk_related_lookup_exists(self): def test_pk_related_lookup_exists(self):
instance = self.field.to_internal_value(self.instance.pk) instance = self.field.to_internal_value(self.instance.pk)
assert instance is self.instance assert instance is self.instance
def test_pk_related_lookup_does_not_exist(self): def test_pk_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo: with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(4) self.field.to_internal_value(4)
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == 'Invalid pk "4" - object does not exist.' assert msg == 'Invalid pk "4" - object does not exist.'
def test_pk_related_lookup_invalid_type(self): def test_pk_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo: with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(BadType()) self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == 'Incorrect type. Expected pk value, received BadType.' assert msg == 'Incorrect type. Expected pk value, received BadType.'
@ -54,7 +57,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
assert representation == self.instance.pk assert representation == self.instance.pk
def test_explicit_many_false(self): def test_explicit_many_false(self):
field = serializers.PrimaryKeyRelatedField(queryset=self.queryset, many=False) field = PrimaryKeyRelatedField(queryset=self.queryset, many=False)
instance = field.to_internal_value(self.instance.pk) instance = field.to_internal_value(self.instance.pk)
assert instance is self.instance assert instance is self.instance
@ -67,9 +70,9 @@ class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase):
MockObject(pk=uuid.UUID(int=2), name='baz') MockObject(pk=uuid.UUID(int=2), name='baz')
]) ])
self.instance = self.queryset.items[2] self.instance = self.queryset.items[2]
self.field = serializers.PrimaryKeyRelatedField( self.field = PrimaryKeyRelatedField(
queryset=self.queryset, queryset=self.queryset,
pk_field=serializers.UUIDField(format='int') pk_field=UUIDField(format='int')
) )
def test_pk_related_lookup_exists(self): def test_pk_related_lookup_exists(self):
@ -77,7 +80,7 @@ class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase):
assert instance is self.instance assert instance is self.instance
def test_pk_related_lookup_does_not_exist(self): def test_pk_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo: with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(4) self.field.to_internal_value(4)
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == 'Invalid pk "00000000-0000-0000-0000-000000000004" - object does not exist.' assert msg == 'Invalid pk "00000000-0000-0000-0000-000000000004" - object does not exist.'
@ -89,7 +92,7 @@ class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase):
class TestHyperlinkedRelatedField(APISimpleTestCase): class TestHyperlinkedRelatedField(APISimpleTestCase):
def setUp(self): def setUp(self):
self.field = serializers.HyperlinkedRelatedField( self.field = HyperlinkedRelatedField(
view_name='example', read_only=True) view_name='example', read_only=True)
self.field.reverse = mock_reverse self.field.reverse = mock_reverse
self.field._context = {'request': True} self.field._context = {'request': True}
@ -102,7 +105,7 @@ class TestHyperlinkedRelatedField(APISimpleTestCase):
class TestHyperlinkedIdentityField(APISimpleTestCase): class TestHyperlinkedIdentityField(APISimpleTestCase):
def setUp(self): def setUp(self):
self.instance = MockObject(pk=1, name='foo') self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example') self.field = HyperlinkedIdentityField(view_name='example')
self.field.reverse = mock_reverse self.field.reverse = mock_reverse
self.field._context = {'request': True} self.field._context = {'request': True}
@ -135,14 +138,15 @@ class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase):
Tests for a hyperlinked identity field that has a `format` set, Tests for a hyperlinked identity field that has a `format` set,
which enforces that alternate formats are never linked too. which enforces that alternate formats are never linked too.
Eg. If your API includes some endpoints that accept both `.xml` and `.json`, Eg. If your API includes some endpoints that accept both `.xml` and `.json`
but other endpoints that only accept `.json`, we allow for hyperlinked but other endpoints only accept `.json`, we allow for hyperlinked
relationships that enforce only a single suffix type. relationships that enforce only a single suffix type.
""" """
def setUp(self): def setUp(self):
self.instance = MockObject(pk=1, name='foo') self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json') self.field = HyperlinkedIdentityField(view_name='example',
format='json')
self.field.reverse = mock_reverse self.field.reverse = mock_reverse
self.field._context = {'request': True} self.field._context = {'request': True}
@ -164,7 +168,7 @@ class TestSlugRelatedField(APISimpleTestCase):
MockObject(pk=3, name='baz') MockObject(pk=3, name='baz')
]) ])
self.instance = self.queryset.items[2] self.instance = self.queryset.items[2]
self.field = serializers.SlugRelatedField( self.field = SlugRelatedField(
slug_field='name', queryset=self.queryset slug_field='name', queryset=self.queryset
) )
@ -173,13 +177,13 @@ class TestSlugRelatedField(APISimpleTestCase):
assert instance is self.instance assert instance is self.instance
def test_slug_related_lookup_does_not_exist(self): def test_slug_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo: with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value('doesnotexist') self.field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == 'Object with name=doesnotexist does not exist.' assert msg == 'Object with name=doesnotexist does not exist.'
def test_slug_related_lookup_invalid_type(self): def test_slug_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo: with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(BadType()) self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0] msg = excinfo.value.detail[0]
assert msg == 'Invalid value.' assert msg == 'Invalid value.'
@ -191,7 +195,7 @@ class TestSlugRelatedField(APISimpleTestCase):
def test_overriding_get_queryset(self): def test_overriding_get_queryset(self):
qs = self.queryset qs = self.queryset
class NoQuerySetSlugRelatedField(serializers.SlugRelatedField): class NoQuerySetSlugRelatedField(SlugRelatedField):
def get_queryset(self): def get_queryset(self):
return qs return qs
@ -202,7 +206,7 @@ class TestSlugRelatedField(APISimpleTestCase):
class TestManyRelatedField(APISimpleTestCase): class TestManyRelatedField(APISimpleTestCase):
def setUp(self): def setUp(self):
self.instance = MockObject(pk=1, name='foo') self.instance = MockObject(pk=1, name='foo')
self.field = serializers.StringRelatedField(many=True) self.field = StringRelatedField(many=True)
self.field.field_name = 'foo' self.field.field_name = 'foo'
def test_get_value_regular_dictionary_full(self): def test_get_value_regular_dictionary_full(self):
@ -232,7 +236,7 @@ class TestManyRelatedField(APISimpleTestCase):
class TestHyperlink: class TestHyperlink:
def setup(self): def setup(self):
self.default_hyperlink = serializers.Hyperlink('http://example.com', 'test') self.default_hyperlink = Hyperlink('http://example.com', 'test')
def test_can_be_pickled(self): def test_can_be_pickled(self):
import pickle import pickle

View File

@ -8,7 +8,8 @@ from django.db import models
from django.test import TestCase from django.test import TestCase
from django.utils.encoding import python_2_unicode_compatible from django.utils.encoding import python_2_unicode_compatible
from rest_framework import serializers from rest_framework.relations import StringRelatedField
from rest_framework.serializers import ModelSerializer
@python_2_unicode_compatible @python_2_unicode_compatible
@ -51,7 +52,8 @@ class Note(models.Model):
class TestGenericRelations(TestCase): class TestGenericRelations(TestCase):
def setUp(self): def setUp(self):
self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/') self.bookmark = Bookmark.objects.create(
url='https://www.djangoproject.com/')
Tag.objects.create(tagged_item=self.bookmark, tag='django') Tag.objects.create(tagged_item=self.bookmark, tag='django')
Tag.objects.create(tagged_item=self.bookmark, tag='python') Tag.objects.create(tagged_item=self.bookmark, tag='python')
self.note = Note.objects.create(text='Remember the milk') self.note = Note.objects.create(text='Remember the milk')
@ -63,8 +65,8 @@ class TestGenericRelations(TestCase):
IE. A reverse generic relationship. IE. A reverse generic relationship.
""" """
class BookmarkSerializer(serializers.ModelSerializer): class BookmarkSerializer(ModelSerializer):
tags = serializers.StringRelatedField(many=True) tags = StringRelatedField(many=True)
class Meta: class Meta:
model = Bookmark model = Bookmark
@ -83,8 +85,8 @@ class TestGenericRelations(TestCase):
IE. A forward generic relationship. IE. A forward generic relationship.
""" """
class TagSerializer(serializers.ModelSerializer): class TagSerializer(ModelSerializer):
tagged_item = serializers.StringRelatedField() tagged_item = StringRelatedField()
class Meta: class Meta:
model = Tag model = Tag

View File

@ -11,7 +11,8 @@ from tests.models import (
) )
factory = APIRequestFactory() factory = APIRequestFactory()
request = factory.get('/') # Just to ensure we have a request in the serializer context request = factory.get(
'/') # Just to ensure we have a request in the serializer context
def dummy_view(request, pk): def dummy_view(request, pk):
@ -20,13 +21,20 @@ def dummy_view(request, pk):
urlpatterns = [ urlpatterns = [
url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'), url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'), url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'), name='manytomanysource-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'), url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'), name='manytomanytarget-detail'),
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'), url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view,
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'), name='foreignkeysource-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'), url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view,
name='foreignkeytarget-detail'),
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view,
name='nullableforeignkeysource-detail'),
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view,
name='onetoonetarget-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view,
name='nullableonetoonesource-detail'),
] ]
@ -57,7 +65,8 @@ class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
# Nullable ForeignKey # Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer): class NullableForeignKeySourceSerializer(serializers.
HyperlinkedModelSerializer):
class Meta: class Meta:
model = NullableForeignKeySource model = NullableForeignKeySource
fields = ('url', 'name', 'target') fields = ('url', 'name', 'target')
@ -84,83 +93,141 @@ class HyperlinkedManyToManyTests(TestCase):
def test_relative_hyperlinks(self): def test_relative_hyperlinks(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None}) serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': None})
expected = [ expected = [
{'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']}, {'url': '/manytomanysource/1/', 'name': 'source-1',
{'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']}, 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']} {'url': '/manytomanysource/2/', 'name': 'source-2',
'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': 'source-3',
'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/',
'/manytomanytarget/3/']}
] ]
with self.assertNumQueries(4): with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_many_to_many_retrieve(self): def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/1/',
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, 'name': 'source-1',
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/',
'name': 'source-2',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/',
'name': 'source-3',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']}
] ]
with self.assertNumQueries(4): with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_many_to_many_retrieve_prefetch_related(self): def test_many_to_many_retrieve_prefetch_related(self):
queryset = ManyToManySource.objects.all().prefetch_related('targets') queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': request})
with self.assertNumQueries(2): with self.assertNumQueries(2):
serializer.data serializer.data
def test_reverse_many_to_many_retrieve(self): def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/1/',
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, 'name': 'target-1',
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} 'sources': ['http://testserver/manytomanysource/1/',
'http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/manytomanysource/3/']}
] ]
with self.assertNumQueries(4): with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self): def test_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} data = {'url': 'http://testserver/manytomanysource/1/',
'name': 'source-1',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1) instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request}) serializer = ManyToManySourceSerializer(instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
serializer.save() serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, {'url': 'http://testserver/manytomanysource/1/',
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, 'name': 'source-1',
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']} 'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/',
'name': 'source-2',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/',
'name': 'source-3',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_update(self): def test_reverse_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']} data = {'url': 'http://testserver/manytomanytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1) instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request}) serializer = ManyToManyTargetSerializer(instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
serializer.save() serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected # Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}, {'url': 'http://testserver/manytomanytarget/1/',
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, 'name': 'target-1',
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']} 'sources': ['http://testserver/manytomanysource/1/']},
{'url': 'http://testserver/manytomanytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/manytomanysource/3/']}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_many_to_many_create(self): def test_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} data = {'url': 'http://testserver/manytomanysource/4/',
serializer = ManyToManySourceSerializer(data=data, context={'request': request}) 'name': 'source-4',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/3/']}
serializer = ManyToManySourceSerializer(data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
@ -168,18 +235,35 @@ class HyperlinkedManyToManyTests(TestCase):
# Ensure source 4 is added, and everything else is as expected # Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all() queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']}, {'url': 'http://testserver/manytomanysource/1/',
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']}, 'name': 'source-1',
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}, 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']} {'url': 'http://testserver/manytomanysource/2/',
'name': 'source-2',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/',
'name': 'source-3',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/4/',
'name': 'source-4',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/3/']}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_create(self): def test_reverse_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} data = {'url': 'http://testserver/manytomanytarget/4/',
serializer = ManyToManyTargetSerializer(data=data, context={'request': request}) 'name': 'target-4',
'sources': ['http://testserver/manytomanysource/1/',
'http://testserver/manytomanysource/3/']}
serializer = ManyToManyTargetSerializer(data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
@ -187,12 +271,25 @@ class HyperlinkedManyToManyTests(TestCase):
# Ensure target 4 is added, and everything else is as expected # Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all() queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ManyToManyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, {'url': 'http://testserver/manytomanytarget/1/',
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']}, 'name': 'target-1',
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}, 'sources': ['http://testserver/manytomanysource/1/',
{'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']} 'http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/4/',
'name': 'target-4',
'sources': ['http://testserver/manytomanysource/1/',
'http://testserver/manytomanysource/3/']}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
@ -210,62 +307,99 @@ class HyperlinkedForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self): def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/1/',
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, 'name': 'source-1',
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/',
'name': 'source-3',
'target': 'http://testserver/foreignkeytarget/1/'}
] ]
with self.assertNumQueries(1): with self.assertNumQueries(1):
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self): def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/1/',
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, 'name': 'target-1',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/2/',
'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2', 'sources': []},
] ]
with self.assertNumQueries(3): with self.assertNumQueries(3):
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self): def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'} data = {'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1',
'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) serializer = ForeignKeySourceSerializer(instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
serializer.save() serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}, {'url': 'http://testserver/foreignkeysource/1/',
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, 'name': 'source-1',
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'} 'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/',
'name': 'source-3',
'target': 'http://testserver/foreignkeytarget/1/'}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self): def test_foreign_key_update_incorrect_type(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2} data = {'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) serializer = ForeignKeySourceSerializer(instance, data=data,
context={'request': request})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected URL string, received int.']}) self.assertEqual(serializer.errors, {
'target': ['Incorrect type. Expected URL string, received int.']})
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} data = {'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request}) serializer = ForeignKeyTargetSerializer(instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save # We shouldn't have saved anything to the db yet since save
# hasn't been called. # hasn't been called.
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) new_serializer = ForeignKeyTargetSerializer(queryset, many=True,
context={
'request': request})
expected = [ expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']}, {'url': 'http://testserver/foreignkeytarget/1/',
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, 'name': 'target-1',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/2/',
'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2', 'sources': []},
] ]
self.assertEqual(new_serializer.data, expected) self.assertEqual(new_serializer.data, expected)
@ -274,16 +408,25 @@ class HyperlinkedForeignKeyTests(TestCase):
# Ensure target 2 is update, and everything else is as expected # Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/1/',
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, 'name': 'target-1',
'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/3/']},
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self): def test_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'} data = {'url': 'http://testserver/foreignkeysource/4/',
serializer = ForeignKeySourceSerializer(data=data, context={'request': request}) 'name': 'source-4',
'target': 'http://testserver/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
@ -291,18 +434,31 @@ class HyperlinkedForeignKeyTests(TestCase):
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all() queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeySourceSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/foreignkeysource/1/',
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, 'name': 'source-1',
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}, 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}, {'url': 'http://testserver/foreignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/',
'name': 'source-3',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/',
'name': 'source-4',
'target': 'http://testserver/foreignkeytarget/2/'},
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']} data = {'url': 'http://testserver/foreignkeytarget/3/',
serializer = ForeignKeyTargetSerializer(data=data, context={'request': request}) 'name': 'target-3',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
@ -310,20 +466,30 @@ class HyperlinkedForeignKeyTests(TestCase):
# Ensure target 4 is added, and everything else is as expected # Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request}) serializer = ForeignKeyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']}, {'url': 'http://testserver/foreignkeytarget/1/',
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []}, 'name': 'target-1',
{'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}, 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/3/']},
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self): def test_foreign_key_update_with_invalid_null(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None} data = {'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request}) serializer = ForeignKeySourceSerializer(instance, data=data,
context={'request': request})
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) self.assertEqual(serializer.errors,
{'target': ['This field may not be null.']})
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink') @override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
@ -334,22 +500,32 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(name='source-%d' % idx,
target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(
queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/1/',
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, 'name': 'source-1',
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self): def test_foreign_key_create_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} data = {'url': 'http://testserver/nullableforeignkeysource/4/',
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={
'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
@ -357,12 +533,20 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
# Ensure source 4 is created, and everything else is as expected # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(
queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/1/',
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, 'name': 'source-1',
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} {'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/',
'name': 'source-4', 'target': None}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
@ -371,9 +555,13 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''} data = {'url': 'http://testserver/nullableforeignkeysource/4/',
expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} 'name': 'source-4', 'target': ''}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request}) expected_data = {
'url': 'http://testserver/nullableforeignkeysource/4/',
'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={
'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
self.assertEqual(serializer.data, expected_data) self.assertEqual(serializer.data, expected_data)
@ -381,30 +569,47 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
# Ensure source 4 is created, and everything else is as expected # Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(
queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'}, {'url': 'http://testserver/nullableforeignkeysource/1/',
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, 'name': 'source-1',
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None} {'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/',
'name': 'source-4', 'target': None}
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self): def test_foreign_key_update_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} data = {'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) serializer = NullableForeignKeySourceSerializer(
instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
serializer.save() serializer.save()
self.assertEqual(serializer.data, data) self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(
queryset, many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/1/',
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
@ -413,21 +618,33 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context The emptystring should be interpreted as null in the context
of relationships. of relationships.
""" """
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''} data = {'url': 'http://testserver/nullableforeignkeysource/1/',
expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None} 'name': 'source-1', 'target': ''}
expected_data = {
'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1) instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request}) serializer = NullableForeignKeySourceSerializer(
instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
serializer.save() serializer.save()
self.assertEqual(serializer.data, expected_data) self.assertEqual(serializer.data, expected_data)
# Ensure source 1 is updated, and everything else is as expected # Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all() queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request}) serializer = NullableForeignKeySourceSerializer(
queryset,
many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/1/',
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'}, 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None}, {'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
@ -444,9 +661,14 @@ class HyperlinkedNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self): def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all() queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request}) serializer = NullableOneToOneTargetSerializer(
queryset,
many=True,
context={'request': request})
expected = [ expected = [
{'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'}, {'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1',
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None}, 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2',
'nullable_source': None},
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)

View File

@ -248,7 +248,8 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]}) self.assertEqual(serializer.errors, {'target': [
'Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]} data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
@ -319,7 +320,8 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) self.assertEqual(serializer.errors,
{'target': ['This field may not be null.']})
def test_foreign_key_with_unsaved(self): def test_foreign_key_with_unsaved(self):
source = ForeignKeySource(name='source-unsaved') source = ForeignKeySource(name='source-unsaved')
@ -345,9 +347,11 @@ class PKForeignKeyTests(TestCase):
Let's say we wanted to fill the non-nullable model field inside Let's say we wanted to fill the non-nullable model field inside
Model.save(), we would make it empty and not required. Model.save(), we would make it empty and not required.
""" """
class ModelSerializer(ForeignKeySourceSerializer): class ModelSerializer(ForeignKeySourceSerializer):
class Meta(ForeignKeySourceSerializer.Meta): class Meta(ForeignKeySourceSerializer.Meta):
extra_kwargs = {'target': {'required': False}} extra_kwargs = {'target': {'required': False}}
serializer = ModelSerializer(data={'name': 'test'}) serializer = ModelSerializer(data={'name': 'test'})
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
self.assertNotIn('target', serializer.validated_data) self.assertNotIn('target', serializer.validated_data)
@ -360,7 +364,8 @@ class PKNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(name='source-%d' % idx,
target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):

View File

@ -73,7 +73,8 @@ class SlugForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True) serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 1, 'name': 'target-1',
'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []}, {'id': 2, 'name': 'target-2', 'sources': []},
] ]
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
@ -107,10 +108,12 @@ class SlugForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']}) self.assertEqual(serializer.errors,
{'target': ['Object with name=123 does not exist.']})
def test_reverse_foreign_key_update(self): def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']} data = {'id': 2, 'name': 'target-2',
'sources': ['source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2) instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data) serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
@ -119,7 +122,8 @@ class SlugForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all() queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True) new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [ expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']}, {'id': 1, 'name': 'target-1',
'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []}, {'id': 2, 'name': 'target-2', 'sources': []},
] ]
self.assertEqual(new_serializer.data, expected) self.assertEqual(new_serializer.data, expected)
@ -157,7 +161,8 @@ class SlugForeignKeyTests(TestCase):
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self): def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']} data = {'id': 3, 'name': 'target-3',
'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data) serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid()) self.assertTrue(serializer.is_valid())
obj = serializer.save() obj = serializer.save()
@ -179,7 +184,8 @@ class SlugForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1) instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data) serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid()) self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field may not be null.']}) self.assertEqual(serializer.errors,
{'target': ['This field may not be null.']})
class SlugNullableForeignKeyTests(TestCase): class SlugNullableForeignKeyTests(TestCase):
@ -189,7 +195,8 @@ class SlugNullableForeignKeyTests(TestCase):
for idx in range(1, 4): for idx in range(1, 4):
if idx == 3: if idx == 3:
target = None target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target) source = NullableForeignKeySource(name='source-%d' % idx,
target=target)
source.save() source.save()
def test_foreign_key_retrieve_with_null(self): def test_foreign_key_retrieve_with_null(self):

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import json import json
import re import re
from collections import MutableMapping, OrderedDict from collections import MutableMapping
from django.conf.urls import include, url from django.conf.urls import include, url
from django.core.cache import cache from django.core.cache import cache
@ -13,11 +13,15 @@ from django.utils import six
from django.utils.safestring import SafeText from django.utils.safestring import SafeText
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import permissions, serializers, status from rest_framework import permissions, status
from rest_framework.fields import (
CharField, ChoiceField, HiddenField, MultipleChoiceField, OrderedDict
)
from rest_framework.renderers import ( from rest_framework.renderers import (
BaseRenderer, BrowsableAPIRenderer, HTMLFormRenderer, JSONRenderer BaseRenderer, BrowsableAPIRenderer, HTMLFormRenderer, JSONRenderer
) )
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.serializers import Serializer
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
@ -92,7 +96,7 @@ class EmptyGETView(APIView):
class HTMLView(APIView): class HTMLView(APIView):
renderer_classes = (BrowsableAPIRenderer, ) renderer_classes = (BrowsableAPIRenderer,)
def get(self, request, **kwargs): def get(self, request, **kwargs):
return Response('text') return Response('text')
@ -104,11 +108,14 @@ class HTMLView1(APIView):
def get(self, request, **kwargs): def get(self, request, **kwargs):
return Response('text') return Response('text')
urlpatterns = [ urlpatterns = [
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^.*\.(?P<format>.+)$',
MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])), url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^cache$', MockGETView.as_view()), url(r'^cache$', MockGETView.as_view()),
url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])), url(r'^parseerror$', MockPOSTView.as_view(
renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
url(r'^html$', HTMLView.as_view()), url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()), url(r'^html1$', HTMLView1.as_view()),
url(r'^empty$', EmptyGETView.as_view()), url(r'^empty$', EmptyGETView.as_view()),
@ -153,10 +160,13 @@ class RendererEndToEndTests(TestCase):
""" """
End-to-end testing of renderers using an RendererMixin on a generic view. End-to-end testing of renderers using an RendererMixin on a generic view.
""" """
def test_default_renderer_serializes_content(self): def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should
serialize the response."""
resp = self.client.get('/') resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
@ -164,34 +174,42 @@ class RendererEndToEndTests(TestCase):
"""No response must be included in HEAD requests.""" """No response must be included in HEAD requests."""
resp = self.client.head('/') resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, six.b('')) self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should
serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*') resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self): def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize
(In this case we check that works for the default renderer)""" the response. (In this case we check that works for the default
renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self): def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize
(In this case we check that works for a non-default renderer)""" the response. (In this case we check that works for a non-default
renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_unsatisfiable_accept_header_on_request_returns_406_status(self): def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response.""" """If the Accept header is unsatisfiable we should return a
406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar') resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE) self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
@ -203,34 +221,41 @@ class RendererEndToEndTests(TestCase):
RendererB.format RendererB.format
) )
resp = self.client.get('/' + param) resp = self.client.get('/' + param)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self): def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching """If a 'format' keyword arg is specified, the renderer with the
format attribute should serialize the response.""" matching format attribute should serialize the response."""
resp = self.client.get('/something.formatb') resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): def test_specified_renderer_is_used_on_format_query_with_matching_accept(
self):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the
response."""
param = '?%s=%s' % ( param = '?%s=%s' % (
api_settings.URL_FORMAT_OVERRIDE, api_settings.URL_FORMAT_OVERRIDE,
RendererB.format RendererB.format
) )
resp = self.client.get('/' + param, resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type) HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_parse_error_renderers_browsable_api(self): def test_parse_error_renderers_browsable_api(self):
"""Invalid data should still render the browsable API correctly.""" """Invalid data should still render the browsable API correctly."""
resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html') resp = self.client.post('/parseerror', data='foobar',
content_type='application/json',
HTTP_ACCEPT='text/html')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
@ -358,18 +383,22 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']} obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2') content = renderer.render(obj, 'application/json; indent=2')
self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr) self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')),
_indented_repr)
class UnicodeJSONRendererTests(TestCase): class UnicodeJSONRendererTests(TestCase):
""" """
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
def test_proper_encoding(self): def test_proper_encoding(self):
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode('utf-8')) self.assertEqual(content,
'{"countries":["United Kingdom","France","España"]}'.
encode('utf-8'))
def test_u2028_u2029(self): def test_u2028_u2029(self):
# The \u2028 and \u2029 characters should be escaped, # The \u2028 and \u2029 characters should be escaped,
@ -378,20 +407,26 @@ class UnicodeJSONRendererTests(TestCase):
obj = {'should_escape': '\u2028\u2029'} obj = {'should_escape': '\u2028\u2029'}
renderer = JSONRenderer() renderer = JSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode('utf-8')) self.assertEqual(content,
'{"should_escape":"\\u2028\\u2029"}'.encode('utf-8'))
class AsciiJSONRendererTests(TestCase): class AsciiJSONRendererTests(TestCase):
""" """
Tests specific for the Unicode JSON Renderer Tests specific for the Unicode JSON Renderer
""" """
def test_proper_encoding(self): def test_proper_encoding(self):
class AsciiJSONRenderer(JSONRenderer): class AsciiJSONRenderer(JSONRenderer):
ensure_ascii = True ensure_ascii = True
obj = {'countries': ['United Kingdom', 'France', 'España']} obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = AsciiJSONRenderer() renderer = AsciiJSONRenderer()
content = renderer.render(obj, 'application/json') content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8')) self.assertEqual(content,
'{"countries":["United Kingdom","France",'
'"Espa\\u00f1a"]}'.encode(
'utf-8'))
# Tests for caching issue, #346 # Tests for caching issue, #346
@ -400,6 +435,7 @@ class CacheRenderTest(TestCase):
""" """
Tests specific to caching responses Tests specific to caching responses
""" """
def test_head_caching(self): def test_head_caching(self):
""" """
Test caching of HEAD requests Test caching of HEAD requests
@ -447,8 +483,8 @@ class TestJSONIndentationStyles:
class TestHiddenFieldHTMLFormRenderer(TestCase): class TestHiddenFieldHTMLFormRenderer(TestCase):
def test_hidden_field_rendering(self): def test_hidden_field_rendering(self):
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
published = serializers.HiddenField(default=True) published = HiddenField(default=True)
serializer = TestSerializer(data={}) serializer = TestSerializer(data={})
serializer.is_valid() serializer.is_valid()
@ -460,8 +496,8 @@ class TestHiddenFieldHTMLFormRenderer(TestCase):
class TestHTMLFormRenderer(TestCase): class TestHTMLFormRenderer(TestCase):
def setUp(self): def setUp(self):
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
test_field = serializers.CharField() test_field = CharField()
self.renderer = HTMLFormRenderer() self.renderer = HTMLFormRenderer()
self.serializer = TestSerializer(data={}) self.serializer = TestSerializer(data={})
@ -491,9 +527,9 @@ class TestChoiceFieldHTMLFormRenderer(TestCase):
def setUp(self): def setUp(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
test_field = serializers.ChoiceField(choices=choices, test_field = ChoiceField(choices=choices,
initial=2) initial=2)
self.TestSerializer = TestSerializer self.TestSerializer = TestSerializer
self.renderer = HTMLFormRenderer() self.renderer = HTMLFormRenderer()
@ -535,8 +571,8 @@ class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'), choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'),
('}', 'OptionBrace')) ('}', 'OptionBrace'))
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
test_field = serializers.MultipleChoiceField(choices=choices) test_field = MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']}) serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid() serializer.is_valid()
@ -554,8 +590,8 @@ class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
def test_render_selected_option_with_integer_option_ids(self): def test_render_selected_option_with_integer_option_ids(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12')) choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
test_field = serializers.MultipleChoiceField(choices=choices) test_field = MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']}) serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid() serializer.is_valid()

View File

@ -32,6 +32,7 @@ class MockJsonRenderer(BaseRenderer):
class MockTextMediaRenderer(BaseRenderer): class MockTextMediaRenderer(BaseRenderer):
media_type = 'text/html' media_type = 'text/html'
DUMMYSTATUS = status.HTTP_200_OK DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent' DUMMYCONTENT = 'dummycontent'
@ -77,7 +78,8 @@ class MockViewSettingContentType(APIView):
renderer_classes = (RendererA, RendererB, RendererC) renderer_classes = (RendererA, RendererB, RendererC)
def get(self, request, **kwargs): def get(self, request, **kwargs):
return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview') return Response(DUMMYCONTENT, status=DUMMYSTATUS,
content_type='setbyview')
class JSONView(APIView): class JSONView(APIView):
@ -89,7 +91,7 @@ class JSONView(APIView):
class HTMLView(APIView): class HTMLView(APIView):
renderer_classes = (BrowsableAPIRenderer, ) renderer_classes = (BrowsableAPIRenderer,)
def get(self, request, **kwargs): def get(self, request, **kwargs):
return Response('text') return Response('text')
@ -117,17 +119,20 @@ class HTMLNewModelView(generics.ListCreateAPIView):
new_model_viewset_router = routers.DefaultRouter() new_model_viewset_router = routers.DefaultRouter()
new_model_viewset_router.register(r'', HTMLNewModelViewSet) new_model_viewset_router.register(r'', HTMLNewModelViewSet)
urlpatterns = [ urlpatterns = [
url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^setbyview$', MockViewSettingContentType.as_view(
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])), url(r'^.*\.(?P<format>.+)$',
MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^$',
MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^html$', HTMLView.as_view()), url(r'^html$', HTMLView.as_view()),
url(r'^json$', JSONView.as_view()), url(r'^json$', JSONView.as_view()),
url(r'^html1$', HTMLView1.as_view()), url(r'^html1$', HTMLView1.as_view()),
url(r'^html_new_model$', HTMLNewModelView.as_view()), url(r'^html_new_model$', HTMLNewModelView.as_view()),
url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)), url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
url(r'^restframework', include('rest_framework.urls', namespace='rest_framework')) url(r'^restframework',
include('rest_framework.urls', namespace='rest_framework'))
] ]
@ -137,10 +142,13 @@ class RendererIntegrationTests(TestCase):
""" """
End-to-end testing of renderers using an ResponseMixin on a generic view. End-to-end testing of renderers using an ResponseMixin on a generic view.
""" """
def test_default_renderer_serializes_content(self): def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response.""" """If the Accept header is not set the default renderer should
serialize the response."""
resp = self.client.get('/') resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
@ -148,29 +156,36 @@ class RendererIntegrationTests(TestCase):
"""No response must be included in HEAD requests.""" """No response must be included in HEAD requests."""
resp = self.client.head('/') resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, six.b('')) self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self): def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response.""" """If the Accept header is set to */* the default renderer should
serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*') resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self): def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize
the response.
(In this case we check that works for the default renderer)""" (In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self): def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response. """If the Accept header is set the specified renderer should serialize
the response.
(In this case we check that works for a non-default renderer)""" (In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type) resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
@ -178,24 +193,29 @@ class RendererIntegrationTests(TestCase):
"""If a 'format' query is specified, the renderer with the matching """If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response.""" format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format) resp = self.client.get('/?format=%s' % RendererB.format)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self): def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching """If a 'format' keyword arg is specified, the renderer with the
format attribute should serialize the response.""" matching format attribute should serialize the response."""
resp = self.client.get('/something.formatb') resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self): def test_specified_renderer_is_used_on_format_query_with_matching_accept(
self):
"""If both a 'format' query and a matching Accept header specified, """If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response.""" the renderer with the matching format attribute should serialize the
response."""
resp = self.client.get('/?format=%s' % RendererB.format, resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type) HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8') self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT)) self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS) self.assertEqual(resp.status_code, DUMMYSTATUS)
@ -203,12 +223,14 @@ class RendererIntegrationTests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_response') @override_settings(ROOT_URLCONF='tests.test_response')
class UnsupportedMediaTypeTests(TestCase): class UnsupportedMediaTypeTests(TestCase):
def test_should_allow_posting_json(self): def test_should_allow_posting_json(self):
response = self.client.post('/json', data='{"test": 123}', content_type='application/json') response = self.client.post('/json', data='{"test": 123}',
content_type='application/json')
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_should_not_allow_posting_xml(self): def test_should_not_allow_posting_xml(self):
response = self.client.post('/json', data='<test>123</test>', content_type='application/xml') response = self.client.post('/json', data='<test>123</test>',
content_type='application/xml')
self.assertEqual(response.status_code, 415) self.assertEqual(response.status_code, 415)
@ -223,6 +245,7 @@ class Issue122Tests(TestCase):
""" """
Tests that covers #122. Tests that covers #122.
""" """
def test_only_html_renderer(self): def test_only_html_renderer(self):
""" """
Test if no infinite recursion occurs. Test if no infinite recursion occurs.
@ -241,6 +264,7 @@ class Issue467Tests(TestCase):
""" """
Tests for #467 Tests for #467
""" """
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
@ -253,6 +277,7 @@ class Issue807Tests(TestCase):
""" """
Covers #807 Covers #807
""" """
def test_does_not_append_charset_by_default(self): def test_does_not_append_charset_by_default(self):
""" """
Renderers don't include a charset unless set explicitly. Renderers don't include a charset unless set explicitly.
@ -269,7 +294,8 @@ class Issue807Tests(TestCase):
""" """
headers = {"HTTP_ACCEPT": RendererC.media_type} headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/', **headers) resp = self.client.get('/', **headers)
expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset) expected = "{0}; charset={1}".format(RendererC.media_type,
RendererC.charset)
self.assertEqual(expected, resp['Content-Type']) self.assertEqual(expected, resp['Content-Type'])
def test_content_type_set_explicitly_on_response(self): def test_content_type_set_explicitly_on_response(self):

View File

@ -6,8 +6,11 @@ import re
import pytest import pytest
from rest_framework import serializers
from rest_framework.compat import unicode_repr from rest_framework.compat import unicode_repr
from rest_framework.fields import (
CharField, IntegerField, RegexField, ValidationError
)
from rest_framework.serializers import BaseSerializer, Serializer
from .utils import MockObject from .utils import MockObject
@ -17,9 +20,10 @@ from .utils import MockObject
class TestSerializer: class TestSerializer:
def setup(self): def setup(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(Serializer):
char = serializers.CharField() char = CharField()
integer = serializers.IntegerField() integer = IntegerField()
self.Serializer = ExampleSerializer self.Serializer = ExampleSerializer
def test_valid_serializer(self): def test_valid_serializer(self):
@ -47,6 +51,7 @@ class TestSerializer:
def test_missing_attribute_during_serialization(self): def test_missing_attribute_during_serialization(self):
class MissingAttributes: class MissingAttributes:
pass pass
instance = MissingAttributes() instance = MissingAttributes()
serializer = self.Serializer(instance) serializer = self.Serializer(instance)
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
@ -55,6 +60,7 @@ class TestSerializer:
def test_data_access_before_save_raises_error(self): def test_data_access_before_save_raises_error(self):
def create(validated_data): def create(validated_data):
return validated_data return validated_data
serializer = self.Serializer(data={'char': 'abc', 'integer': 123}) serializer = self.Serializer(data={'char': 'abc', 'integer': 123})
serializer.create = create serializer.create = create
assert serializer.is_valid() assert serializer.is_valid()
@ -71,24 +77,24 @@ class TestSerializer:
class TestValidateMethod: class TestValidateMethod:
def test_non_field_error_validate_method(self): def test_non_field_error_validate_method(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(Serializer):
char = serializers.CharField() char = CharField()
integer = serializers.IntegerField() integer = IntegerField()
def validate(self, attrs): def validate(self, attrs):
raise serializers.ValidationError('Non field error') raise ValidationError('Non field error')
serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
assert not serializer.is_valid() assert not serializer.is_valid()
assert serializer.errors == {'non_field_errors': ['Non field error']} assert serializer.errors == {'non_field_errors': ['Non field error']}
def test_field_error_validate_method(self): def test_field_error_validate_method(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(Serializer):
char = serializers.CharField() char = CharField()
integer = serializers.IntegerField() integer = IntegerField()
def validate(self, attrs): def validate(self, attrs):
raise serializers.ValidationError({'char': 'Field error'}) raise ValidationError({'char': 'Field error'})
serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123}) serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
assert not serializer.is_valid() assert not serializer.is_valid()
@ -97,7 +103,7 @@ class TestValidateMethod:
class TestBaseSerializer: class TestBaseSerializer:
def setup(self): def setup(self):
class ExampleSerializer(serializers.BaseSerializer): class ExampleSerializer(BaseSerializer):
def to_representation(self, obj): def to_representation(self, obj):
return { return {
'id': obj['id'], 'id': obj['id'],
@ -167,15 +173,15 @@ class TestStarredSource:
} }
def setup(self): def setup(self):
class NestedSerializer1(serializers.Serializer): class NestedSerializer1(Serializer):
a = serializers.IntegerField() a = IntegerField()
b = serializers.IntegerField() b = IntegerField()
class NestedSerializer2(serializers.Serializer): class NestedSerializer2(Serializer):
c = serializers.IntegerField() c = IntegerField()
d = serializers.IntegerField() d = IntegerField()
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
nested1 = NestedSerializer1(source='*') nested1 = NestedSerializer1(source='*')
nested2 = NestedSerializer2(source='*') nested2 = NestedSerializer2(source='*')
@ -205,8 +211,8 @@ class TestStarredSource:
class TestIncorrectlyConfigured: class TestIncorrectlyConfigured:
def test_incorrect_field_name(self): def test_incorrect_field_name(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(Serializer):
incorrect_name = serializers.IntegerField() incorrect_name = IntegerField()
class ExampleObject: class ExampleObject:
def __init__(self): def __init__(self):
@ -226,8 +232,8 @@ class TestIncorrectlyConfigured:
class TestUnicodeRepr: class TestUnicodeRepr:
def test_unicode_repr(self): def test_unicode_repr(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(Serializer):
example = serializers.CharField() example = CharField()
class ExampleObject: class ExampleObject:
def __init__(self): def __init__(self):
@ -246,9 +252,10 @@ class TestNotRequiredOutput:
""" """
'required=False' should allow a dictionary key to be missing in output. 'required=False' should allow a dictionary key to be missing in output.
""" """
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(required=False) class ExampleSerializer(Serializer):
included = serializers.CharField() omitted = CharField(required=False)
included = CharField()
serializer = ExampleSerializer(data={'included': 'abc'}) serializer = ExampleSerializer(data={'included': 'abc'})
serializer.is_valid() serializer.is_valid()
@ -258,9 +265,10 @@ class TestNotRequiredOutput:
""" """
'required=False' should allow an object attribute to be missing in output. 'required=False' should allow an object attribute to be missing in output.
""" """
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(required=False) class ExampleSerializer(Serializer):
included = serializers.CharField() omitted = CharField(required=False)
included = CharField()
def create(self, validated_data): def create(self, validated_data):
return MockObject(**validated_data) return MockObject(**validated_data)
@ -277,9 +285,10 @@ class TestNotRequiredOutput:
We need to handle this as the field will have an implicit We need to handle this as the field will have an implicit
'required=False', but it should still have a value. 'required=False', but it should still have a value.
""" """
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(default='abc') class ExampleSerializer(Serializer):
included = serializers.CharField() omitted = CharField(default='abc')
included = CharField()
serializer = ExampleSerializer({'included': 'abc'}) serializer = ExampleSerializer({'included': 'abc'})
with pytest.raises(KeyError): with pytest.raises(KeyError):
@ -292,9 +301,10 @@ class TestNotRequiredOutput:
We need to handle this as the field will have an implicit We need to handle this as the field will have an implicit
'required=False', but it should still have a value. 'required=False', but it should still have a value.
""" """
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(default='abc') class ExampleSerializer(Serializer):
included = serializers.CharField() omitted = CharField(default='abc')
included = CharField()
instance = MockObject(included='abc') instance = MockObject(included='abc')
serializer = ExampleSerializer(instance) serializer = ExampleSerializer(instance)
@ -308,9 +318,10 @@ class TestCacheSerializerData:
Caching serializer data with pickle will drop the serializer info, Caching serializer data with pickle will drop the serializer info,
but does preserve the data itself. but does preserve the data itself.
""" """
class ExampleSerializer(serializers.Serializer):
field1 = serializers.CharField() class ExampleSerializer(Serializer):
field2 = serializers.CharField() field1 = CharField()
field2 = CharField()
serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'}) serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'})
pickled = pickle.dumps(serializer.data) pickled = pickle.dumps(serializer.data)
@ -320,9 +331,10 @@ class TestCacheSerializerData:
class TestDefaultInclusions: class TestDefaultInclusions:
def setup(self): def setup(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(Serializer):
char = serializers.CharField(read_only=True, default='abc') char = CharField(read_only=True, default='abc')
integer = serializers.IntegerField() integer = IntegerField()
self.Serializer = ExampleSerializer self.Serializer = ExampleSerializer
def test_default_should_included_on_create(self): def test_default_should_included_on_create(self):
@ -340,7 +352,8 @@ class TestDefaultInclusions:
def test_default_should_not_be_included_on_partial_update(self): def test_default_should_not_be_included_on_partial_update(self):
instance = MockObject(char='def', integer=123) instance = MockObject(char='def', integer=123)
serializer = self.Serializer(instance, data={'integer': 456}, partial=True) serializer = self.Serializer(instance, data={'integer': 456},
partial=True)
assert serializer.is_valid() assert serializer.is_valid()
assert serializer.validated_data == {'integer': 456} assert serializer.validated_data == {'integer': 456}
assert serializer.errors == {} assert serializer.errors == {}
@ -348,8 +361,9 @@ class TestDefaultInclusions:
class TestSerializerValidationWithCompiledRegexField: class TestSerializerValidationWithCompiledRegexField:
def setup(self): def setup(self):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(Serializer):
name = serializers.RegexField(re.compile(r'\d'), required=True) name = RegexField(re.compile(r'\d'), required=True)
self.Serializer = ExampleSerializer self.Serializer = ExampleSerializer
def test_validation_success(self): def test_validation_success(self):

View File

@ -1,12 +1,17 @@
from django.utils.datastructures import MultiValueDict from django.utils.datastructures import MultiValueDict
from rest_framework import serializers from rest_framework.fields import (
BooleanField, IntegerField, ListField, MultipleChoiceField,
ValidationError
)
from rest_framework.serializers import ListSerializer, Serializer
class BasicObject: class BasicObject:
""" """
A mock object for testing serializer save behavior. A mock object for testing serializer save behavior.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._data = kwargs self._data = kwargs
for key, value in kwargs.items(): for key, value in kwargs.items():
@ -28,8 +33,9 @@ class TestListSerializer:
""" """
def setup(self): def setup(self):
class IntegerListSerializer(serializers.ListSerializer): class IntegerListSerializer(ListSerializer):
child = serializers.IntegerField() child = IntegerField()
self.Serializer = IntegerListSerializer self.Serializer = IntegerListSerializer
def test_validate(self): def test_validate(self):
@ -59,14 +65,14 @@ class TestListSerializerContainingNestedSerializer:
""" """
def setup(self): def setup(self):
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
integer = serializers.IntegerField() integer = IntegerField()
boolean = serializers.BooleanField() boolean = BooleanField()
def create(self, validated_data): def create(self, validated_data):
return BasicObject(**validated_data) return BasicObject(**validated_data)
class ObjectListSerializer(serializers.ListSerializer): class ObjectListSerializer(ListSerializer):
child = TestSerializer() child = TestSerializer()
self.Serializer = ObjectListSerializer self.Serializer = ObjectListSerializer
@ -145,9 +151,9 @@ class TestNestedListSerializer:
""" """
def setup(self): def setup(self):
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
integers = serializers.ListSerializer(child=serializers.IntegerField()) integers = ListSerializer(child=IntegerField())
booleans = serializers.ListSerializer(child=serializers.BooleanField()) booleans = ListSerializer(child=BooleanField())
def create(self, validated_data): def create(self, validated_data):
return BasicObject(**validated_data) return BasicObject(**validated_data)
@ -224,15 +230,15 @@ class TestNestedListSerializer:
class TestNestedListOfListsSerializer: class TestNestedListOfListsSerializer:
def setup(self): def setup(self):
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
integers = serializers.ListSerializer( integers = ListSerializer(
child=serializers.ListSerializer( child=ListSerializer(
child=serializers.IntegerField() child=IntegerField()
) )
) )
booleans = serializers.ListSerializer( booleans = ListSerializer(
child=serializers.ListSerializer( child=ListSerializer(
child=serializers.BooleanField() child=BooleanField()
) )
) )
@ -277,12 +283,13 @@ class TestNestedListOfListsSerializer:
class TestListSerializerClass: class TestListSerializerClass:
"""Tests for a custom list_serializer_class.""" """Tests for a custom list_serializer_class."""
def test_list_serializer_class_validate(self):
class CustomListSerializer(serializers.ListSerializer):
def validate(self, attrs):
raise serializers.ValidationError('Non field error')
class TestSerializer(serializers.Serializer): def test_list_serializer_class_validate(self):
class CustomListSerializer(ListSerializer):
def validate(self, attrs):
raise ValidationError('Non field error')
class TestSerializer(Serializer):
class Meta: class Meta:
list_serializer_class = CustomListSerializer list_serializer_class = CustomListSerializer
@ -299,9 +306,11 @@ class TestSerializerPartialUsage:
Regression test for Github issue #2761. Regression test for Github issue #2761.
""" """
def test_partial_listfield(self): def test_partial_listfield(self):
class ListSerializer(serializers.Serializer): class ListSerializer(Serializer):
listdata = serializers.ListField() listdata = ListField()
serializer = ListSerializer(data=MultiValueDict(), partial=True) serializer = ListSerializer(data=MultiValueDict(), partial=True)
result = serializer.to_internal_value(data={}) result = serializer.to_internal_value(data={})
assert "listdata" not in result assert "listdata" not in result
@ -310,9 +319,11 @@ class TestSerializerPartialUsage:
assert serializer.errors == {} assert serializer.errors == {}
def test_partial_multiplechoice(self): def test_partial_multiplechoice(self):
class MultipleChoiceSerializer(serializers.Serializer): class MultipleChoiceSerializer(Serializer):
multiplechoice = serializers.MultipleChoiceField(choices=[1, 2, 3]) multiplechoice = MultipleChoiceField(choices=[1, 2, 3])
serializer = MultipleChoiceSerializer(data=MultiValueDict(), partial=True)
serializer = MultipleChoiceSerializer(data=MultiValueDict(),
partial=True)
result = serializer.to_internal_value(data={}) result = serializer.to_internal_value(data={})
assert "multiplechoice" not in result assert "multiplechoice" not in result
assert serializer.is_valid() assert serializer.is_valid()

View File

@ -1,15 +1,16 @@
from django.http import QueryDict from django.http import QueryDict
from rest_framework import serializers from rest_framework.fields import IntegerField, MultipleChoiceField
from rest_framework.serializers import ListSerializer, Serializer
class TestNestedSerializer: class TestNestedSerializer:
def setup(self): def setup(self):
class NestedSerializer(serializers.Serializer): class NestedSerializer(Serializer):
one = serializers.IntegerField(max_value=10) one = IntegerField(max_value=10)
two = serializers.IntegerField(max_value=10) two = IntegerField(max_value=10)
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
nested = NestedSerializer() nested = NestedSerializer()
self.Serializer = TestSerializer self.Serializer = TestSerializer
@ -50,10 +51,10 @@ class TestNestedSerializer:
class TestNotRequiredNestedSerializer: class TestNotRequiredNestedSerializer:
def setup(self): def setup(self):
class NestedSerializer(serializers.Serializer): class NestedSerializer(Serializer):
one = serializers.IntegerField(max_value=10) one = IntegerField(max_value=10)
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
nested = NestedSerializer(required=False) nested = NestedSerializer(required=False)
self.Serializer = TestSerializer self.Serializer = TestSerializer
@ -79,10 +80,10 @@ class TestNotRequiredNestedSerializer:
class TestNestedSerializerWithMany: class TestNestedSerializerWithMany:
def setup(self): def setup(self):
class NestedSerializer(serializers.Serializer): class NestedSerializer(Serializer):
example = serializers.IntegerField(max_value=10) example = IntegerField(max_value=10)
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
allow_null = NestedSerializer(many=True, allow_null=True) allow_null = NestedSerializer(many=True, allow_null=True)
not_allow_null = NestedSerializer(many=True) not_allow_null = NestedSerializer(many=True)
allow_empty = NestedSerializer(many=True, allow_empty=True) allow_empty = NestedSerializer(many=True, allow_empty=True)
@ -119,7 +120,8 @@ class TestNestedSerializerWithMany:
assert not serializer.is_valid() assert not serializer.is_valid()
expected_errors = {'not_allow_null': [serializer.error_messages['null']]} expected_errors = {
'not_allow_null': [serializer.error_messages['null']]}
assert serializer.errors == expected_errors assert serializer.errors == expected_errors
def test_run_the_field_validation_even_if_the_field_is_null(self): def test_run_the_field_validation_even_if_the_field_is_null(self):
@ -171,16 +173,17 @@ class TestNestedSerializerWithMany:
assert not serializer.is_valid() assert not serializer.is_valid()
expected_errors = {'not_allow_empty': {'non_field_errors': [serializers.ListSerializer.default_error_messages['empty']]}} expected_errors = {'not_allow_empty': {'non_field_errors': [
ListSerializer.default_error_messages['empty']]}}
assert serializer.errors == expected_errors assert serializer.errors == expected_errors
class TestNestedSerializerWithList: class TestNestedSerializerWithList:
def setup(self): def setup(self):
class NestedSerializer(serializers.Serializer): class NestedSerializer(Serializer):
example = serializers.MultipleChoiceField(choices=[1, 2, 3]) example = MultipleChoiceField(choices=[1, 2, 3])
class TestSerializer(serializers.Serializer): class TestSerializer(Serializer):
nested = NestedSerializer() nested = NestedSerializer()
self.Serializer = TestSerializer self.Serializer = TestSerializer

View File

@ -43,7 +43,8 @@ urlpatterns = [
url(r'^resource/customname$', CustomNameResourceInstance.as_view()), url(r'^resource/customname$', CustomNameResourceInstance.as_view()),
url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()), url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()), url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()), url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$',
NestedResourceInstance.as_view()),
] ]
@ -52,6 +53,7 @@ class BreadcrumbTests(TestCase):
""" """
Tests the breadcrumb functionality used by the HTML renderer. Tests the breadcrumb functionality used by the HTML renderer.
""" """
def test_root_breadcrumbs(self): def test_root_breadcrumbs(self):
url = '/' url = '/'
self.assertEqual( self.assertEqual(
@ -130,6 +132,7 @@ class ResolveModelTests(TestCase):
provided argument is a Django model class itself, or a properly provided argument is a Django model class itself, or a properly
formatted string representation of one. formatted string representation of one.
""" """
def test_resolve_django_model(self): def test_resolve_django_model(self):
resolved_model = _resolve_model(BasicModel) resolved_model = _resolve_model(BasicModel)
self.assertEqual(resolved_model, BasicModel) self.assertEqual(resolved_model, BasicModel)

View File

@ -1,4 +1,5 @@
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from rest_framework.compat import NoReverseMatch from rest_framework.compat import NoReverseMatch