PEP8 and pyflakes fixes

This commit is contained in:
Adam Nelson 2016-10-21 11:20:10 -04:00
parent 28b801bc1b
commit 556cf1205d
25 changed files with 1467 additions and 913 deletions

View File

@ -8,14 +8,20 @@ to return this information in a more standardized way.
"""
from __future__ import unicode_literals
from collections import OrderedDict
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.utils.encoding import force_text
from rest_framework import exceptions, serializers
from rest_framework import exceptions
from rest_framework.fields import (
BooleanField, CharField, ChoiceField, DateField, DateTimeField,
DecimalField, DictField, EmailField, Field, FileField, FloatField,
ImageField, IntegerField, ListField, MultipleChoiceField, NullBooleanField,
OrderedDict, RegexField, SlugField, TimeField, URLField
)
from rest_framework.relations import ManyRelatedField, RelatedField
from rest_framework.request import clone_request
from rest_framework.serializers import Serializer
from rest_framework.utils.field_mapping import ClassLookupDict
@ -36,35 +42,37 @@ class SimpleMetadata(BaseMetadata):
for us to base this on.
"""
label_lookup = ClassLookupDict({
serializers.Field: 'field',
serializers.BooleanField: 'boolean',
serializers.NullBooleanField: 'boolean',
serializers.CharField: 'string',
serializers.URLField: 'url',
serializers.EmailField: 'email',
serializers.RegexField: 'regex',
serializers.SlugField: 'slug',
serializers.IntegerField: 'integer',
serializers.FloatField: 'float',
serializers.DecimalField: 'decimal',
serializers.DateField: 'date',
serializers.DateTimeField: 'datetime',
serializers.TimeField: 'time',
serializers.ChoiceField: 'choice',
serializers.MultipleChoiceField: 'multiple choice',
serializers.FileField: 'file upload',
serializers.ImageField: 'image upload',
serializers.ListField: 'list',
serializers.DictField: 'nested object',
serializers.Serializer: 'nested object',
Field: 'field',
BooleanField: 'boolean',
NullBooleanField: 'boolean',
CharField: 'string',
URLField: 'url',
EmailField: 'email',
RegexField: 'regex',
SlugField: 'slug',
IntegerField: 'integer',
FloatField: 'float',
DecimalField: 'decimal',
DateField: 'date',
DateTimeField: 'datetime',
TimeField: 'time',
ChoiceField: 'choice',
MultipleChoiceField: 'multiple choice',
FileField: 'file upload',
ImageField: 'image upload',
ListField: 'list',
DictField: 'nested object',
Serializer: 'nested object',
})
def determine_metadata(self, request, view):
metadata = OrderedDict()
metadata['name'] = view.get_view_name()
metadata['description'] = view.get_view_description()
metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes]
metadata['parses'] = [parser.media_type for parser in view.parser_classes]
metadata['renders'] = [renderer.media_type for renderer in
view.renderer_classes]
metadata['parses'] = [parser.media_type for parser in
view.parser_classes]
if hasattr(view, 'get_serializer'):
actions = self.determine_actions(request, view)
if actions:
@ -90,7 +98,7 @@ class SimpleMetadata(BaseMetadata):
pass
else:
# If user has appropriate permissions for the view, include
# appropriate metadata about the fields that should be supplied.
# appropriate metadata about the fields that should be supplied
serializer = view.get_serializer()
actions[method] = self.get_serializer_info(serializer)
finally:
@ -107,10 +115,9 @@ class SimpleMetadata(BaseMetadata):
# If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead.
serializer = serializer.child
return OrderedDict([
(field_name, self.get_field_info(field))
for field_name, field in serializer.fields.items()
])
return OrderedDict([(field_name, self.get_field_info(field))
for field_name, field in
serializer.fields.items()])
def get_field_info(self, field):
"""
@ -138,14 +145,12 @@ class SimpleMetadata(BaseMetadata):
field_info['children'] = self.get_serializer_info(field)
if (not field_info.get('read_only') and
not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and
not isinstance(field, (RelatedField, ManyRelatedField)) and
hasattr(field, 'choices')):
field_info['choices'] = [
{
'value': choice_value,
'display_name': force_text(choice_name, strings_only=True)
}
for choice_value, choice_name in field.choices.items()
]
field_info['choices'] = [{'value': choice_value,
'display_name': force_text(
choice_name, strings_only=True)}
for choice_value, choice_name in
field.choices.items()]
return field_info

View File

@ -7,7 +7,6 @@ from django.http import Http404
from rest_framework.compat import is_authenticated
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')

View File

@ -9,24 +9,30 @@ REST framework also provides an HTML renderer that renders the browsable API.
from __future__ import unicode_literals
import json
from collections import OrderedDict
from django import forms
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.paginator import Page
from django.http.multipartparser import parse_header
from django.template import Template, loader
from django.test.client import encode_multipart
from django.utils import six
from rest_framework import VERSION, exceptions, serializers, status
from rest_framework import VERSION, exceptions, status
from rest_framework.compat import (
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi,
template_render
)
from rest_framework.exceptions import ParseError
from rest_framework.fields import (
BooleanField, ChoiceField, DateField, DateTimeField, EmailField, Field,
FileField, FilePathField, FloatField, HiddenField, IntegerField,
MultipleChoiceField, OrderedDict, TimeField, URLField, six
)
from rest_framework.relations import (
ImproperlyConfigured, ManyRelatedField, RelatedField
)
from rest_framework.request import is_form_media_type, override_method
from rest_framework.serializers import ListSerializer, Serializer
from rest_framework.settings import api_settings
from rest_framework.utils import encoders
from rest_framework.utils.breadcrumbs import get_breadcrumbs
@ -48,7 +54,8 @@ class BaseRenderer(object):
render_style = 'text'
def render(self, data, accepted_media_type=None, renderer_context=None):
raise NotImplementedError('Renderer class requires .render() to be implemented')
raise NotImplementedError(
'Renderer class requires .render() to be implemented')
class JSONRenderer(BaseRenderer):
@ -64,7 +71,7 @@ class JSONRenderer(BaseRenderer):
# We don't set a charset because JSON is a binary encoding,
# that can be encoded as utf-8, utf-16 or utf-32.
# See: http://www.ietf.org/rfc/rfc4627.txt
# Also: http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/
# http://lucumr.pocoo.org/2013/7/19/application-mimetypes-and-encodings/
charset = None
def get_indent(self, accepted_media_type, renderer_context):
@ -72,7 +79,8 @@ class JSONRenderer(BaseRenderer):
# If the media type looks like 'application/json; indent=4',
# then pretty print the result.
# Note that we coerce `indent=0` into `indent=None`.
base_media_type, params = parse_header(accepted_media_type.encode('ascii'))
base_media_type, params = parse_header(
accepted_media_type.encode('ascii'))
try:
return zero_as_none(max(min(int(params['indent']), 8), 0))
except (KeyError, ValueError, TypeError):
@ -192,7 +200,8 @@ class TemplateHTMLRenderer(BaseRenderer):
elif hasattr(view, 'template_name'):
return [view.template_name]
raise ImproperlyConfigured(
'Returned a template response with no `template_name` attribute set on either the view or response'
'Returned a template response with no `template_name` attribute '
'set on either the view or response'
)
def get_exception_template(self, response):
@ -260,88 +269,93 @@ class HTMLFormRenderer(BaseRenderer):
base_template = 'form.html'
default_style = ClassLookupDict({
serializers.Field: {
Field: {
'base_template': 'input.html',
'input_type': 'text'
},
serializers.EmailField: {
EmailField: {
'base_template': 'input.html',
'input_type': 'email'
},
serializers.URLField: {
URLField: {
'base_template': 'input.html',
'input_type': 'url'
},
serializers.IntegerField: {
IntegerField: {
'base_template': 'input.html',
'input_type': 'number'
},
serializers.FloatField: {
FloatField: {
'base_template': 'input.html',
'input_type': 'number'
},
serializers.DateTimeField: {
DateTimeField: {
'base_template': 'input.html',
'input_type': 'datetime-local'
},
serializers.DateField: {
DateField: {
'base_template': 'input.html',
'input_type': 'date'
},
serializers.TimeField: {
TimeField: {
'base_template': 'input.html',
'input_type': 'time'
},
serializers.FileField: {
FileField: {
'base_template': 'input.html',
'input_type': 'file'
},
serializers.BooleanField: {
BooleanField: {
'base_template': 'checkbox.html'
},
serializers.ChoiceField: {
ChoiceField: {
'base_template': 'select.html', # Also valid: 'radio.html'
},
serializers.MultipleChoiceField: {
'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html'
MultipleChoiceField: {
'base_template': 'select_multiple.html',
# Also valid: 'checkbox_multiple.html'
},
serializers.RelatedField: {
RelatedField: {
'base_template': 'select.html', # Also valid: 'radio.html'
},
serializers.ManyRelatedField: {
'base_template': 'select_multiple.html', # Also valid: 'checkbox_multiple.html'
ManyRelatedField: {
'base_template': 'select_multiple.html',
# Also valid: 'checkbox_multiple.html'
},
serializers.Serializer: {
Serializer: {
'base_template': 'fieldset.html'
},
serializers.ListSerializer: {
ListSerializer: {
'base_template': 'list_fieldset.html'
},
serializers.FilePathField: {
FilePathField: {
'base_template': 'select.html',
},
})
def render_field(self, field, parent_style):
if isinstance(field._field, serializers.HiddenField):
if isinstance(field._field, HiddenField):
return ''
style = dict(self.default_style[field])
style.update(field.style)
if 'template_pack' not in style:
style['template_pack'] = parent_style.get('template_pack', self.template_pack)
style['template_pack'] = parent_style.get('template_pack',
self.template_pack)
style['renderer'] = self
# Get a clone of the field with text-only value representation.
field = field.as_form_field()
if style.get('input_type') == 'datetime-local' and isinstance(field.value, six.text_type):
if style.get('input_type') == 'datetime-local' and isinstance(
field.value, six.text_type):
field.value = field.value.rstrip('Z')
if 'template' in style:
template_name = style['template']
else:
template_name = style['template_pack'].strip('/') + '/' + style['base_template']
template_name = style['template_pack'].strip('/') + '/' + style[
'base_template']
template = loader.get_template(template_name)
context = {'field': field, 'style': style}
@ -388,7 +402,8 @@ class BrowsableAPIRenderer(BaseRenderer):
renderers = [renderer for renderer in view.renderer_classes
if not issubclass(renderer, BrowsableAPIRenderer)]
non_template_renderers = [renderer for renderer in renderers
if not hasattr(renderer, 'get_template_names')]
if
not hasattr(renderer, 'get_template_names')]
if not renderers:
return None
@ -410,7 +425,8 @@ class BrowsableAPIRenderer(BaseRenderer):
render_style = getattr(renderer, 'render_style', 'text')
assert render_style in ['text', 'binary'], 'Expected .render_style ' \
'"text" or "binary", but got "%s"' % render_style
'"text" or "binary", but ' \
'got "%s"' % render_style
if render_style == 'binary':
return '[%d bytes of binary content]' % len(content)
@ -431,7 +447,8 @@ class BrowsableAPIRenderer(BaseRenderer):
return False # Doesn't have permissions
return True
def _get_serializer(self, serializer_class, view_instance, request, *args, **kwargs):
def _get_serializer(self, serializer_class, view_instance, request, *args,
**kwargs):
kwargs['context'] = {
'request': request,
'format': self.format,
@ -478,10 +495,10 @@ class BrowsableAPIRenderer(BaseRenderer):
has_serializer = getattr(view, 'get_serializer', None)
has_serializer_class = getattr(view, 'serializer_class', None)
if (
(not has_serializer and not has_serializer_class) or
not any(is_form_media_type(parser.media_type) for parser in view.parser_classes)
):
if ((not has_serializer and not has_serializer_class) or
not any(
is_form_media_type(parser.media_type) for parser in
view.parser_classes)):
return
if existing_serializer is not None:
@ -492,16 +509,21 @@ class BrowsableAPIRenderer(BaseRenderer):
if has_serializer:
if method in ('PUT', 'PATCH'):
serializer = view.get_serializer(instance=instance, **kwargs)
serializer = view.get_serializer(instance=instance,
**kwargs)
else:
serializer = view.get_serializer(**kwargs)
else:
# at this point we must have a serializer_class
if method in ('PUT', 'PATCH'):
serializer = self._get_serializer(view.serializer_class, view,
request, instance=instance, **kwargs)
serializer = self._get_serializer(view.serializer_class,
view,
request,
instance=instance,
**kwargs)
else:
serializer = self._get_serializer(view.serializer_class, view,
serializer = self._get_serializer(view.serializer_class,
view,
request, **kwargs)
return self.render_form_for_serializer(serializer)
@ -569,7 +591,8 @@ class BrowsableAPIRenderer(BaseRenderer):
label='Media type',
choices=choices,
initial=initial,
widget=forms.Select(attrs={'data-override': 'content-type'})
widget=forms.Select(
attrs={'data-override': 'content-type'})
)
_content = forms.CharField(
label='Content',
@ -583,7 +606,8 @@ class BrowsableAPIRenderer(BaseRenderer):
return view.get_view_name()
def get_description(self, view, status_code):
if status_code in (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN):
if status_code in (
status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN):
return ''
return view.get_view_description(html=True)
@ -591,7 +615,8 @@ class BrowsableAPIRenderer(BaseRenderer):
return get_breadcrumbs(request.path, request)
def get_filter_form(self, data, view, request):
if not hasattr(view, 'get_queryset') or not hasattr(view, 'filter_backends'):
if not hasattr(view, 'get_queryset') or not hasattr(view,
'filter_backends'):
return
# Infer if this is a list view or not.
@ -631,9 +656,11 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer = self.get_default_renderer(view)
raw_data_post_form = self.get_raw_data_form(data, view, 'POST', request)
raw_data_post_form = self.get_raw_data_form(data, view, 'POST',
request)
raw_data_put_form = self.get_raw_data_form(data, view, 'PUT', request)
raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH', request)
raw_data_patch_form = self.get_raw_data_form(data, view, 'PATCH',
request)
raw_data_put_or_patch_form = raw_data_put_form or raw_data_patch_form
response_headers = OrderedDict(sorted(response.items()))
@ -644,19 +671,22 @@ class BrowsableAPIRenderer(BaseRenderer):
renderer_content_type += ' ;%s' % renderer.charset
response_headers['Content-Type'] = renderer_content_type
if getattr(view, 'paginator', None) and view.paginator.display_page_controls:
if getattr(view, 'paginator',
None) and view.paginator.display_page_controls:
paginator = view.paginator
else:
paginator = None
csrf_cookie_name = settings.CSRF_COOKIE_NAME
csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME', 'HTTP_X_CSRFToken') # Fallback for Django 1.8
csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME',
'HTTP_X_CSRFToken') # Fallback Django 1.8
if csrf_header_name.startswith('HTTP_'):
csrf_header_name = csrf_header_name[5:]
csrf_header_name = csrf_header_name.replace('_', '-')
context = {
'content': self.get_content(renderer, data, accepted_media_type, renderer_context),
'content': self.get_content(renderer, data, accepted_media_type,
renderer_context),
'view': view,
'request': request,
'response': response,
@ -667,13 +697,18 @@ class BrowsableAPIRenderer(BaseRenderer):
'paginator': paginator,
'breadcrumblist': self.get_breadcrumbs(request),
'allowed_methods': view.allowed_methods,
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'available_formats': [renderer_cls.format for renderer_cls in
view.renderer_classes],
'response_headers': response_headers,
'put_form': self.get_rendered_html_form(data, view, 'PUT', request),
'post_form': self.get_rendered_html_form(data, view, 'POST', request),
'delete_form': self.get_rendered_html_form(data, view, 'DELETE', request),
'options_form': self.get_rendered_html_form(data, view, 'OPTIONS', request),
'put_form': self.get_rendered_html_form(data, view, 'PUT',
request),
'post_form': self.get_rendered_html_form(data, view, 'POST',
request),
'delete_form': self.get_rendered_html_form(data, view, 'DELETE',
request),
'options_form': self.get_rendered_html_form(data, view, 'OPTIONS',
request),
'filter_form': self.get_filter_form(data, view, request),
@ -699,7 +734,8 @@ class BrowsableAPIRenderer(BaseRenderer):
template = loader.get_template(self.template)
context = self.get_context(data, accepted_media_type, renderer_context)
ret = template_render(template, context, request=renderer_context['request'])
ret = template_render(template, context,
request=renderer_context['request'])
# Munge DELETE Response code to allow us to return content
# (Do this *after* we've rendered the template so that we include
@ -726,8 +762,11 @@ class AdminRenderer(BrowsableAPIRenderer):
if response.status_code == status.HTTP_400_BAD_REQUEST:
# Errors still need to display the list or detail information.
# The only way we can get at that is to simulate a GET request.
self.error_form = self.get_rendered_html_form(data, view, request.method, request)
self.error_title = {'POST': 'Create', 'PUT': 'Edit'}.get(request.method, 'Errors')
self.error_form = self.get_rendered_html_form(data, view,
request.method,
request)
self.error_title = {'POST': 'Create', 'PUT': 'Edit'}.get(
request.method, 'Errors')
with override_method(view, request, 'GET') as request:
response = view.get(request, *view.args, **view.kwargs)
@ -735,10 +774,12 @@ class AdminRenderer(BrowsableAPIRenderer):
template = loader.get_template(self.template)
context = self.get_context(data, accepted_media_type, renderer_context)
ret = template_render(template, context, request=renderer_context['request'])
ret = template_render(template, context,
request=renderer_context['request'])
# Creation and deletion should use redirects in the admin style.
if (response.status_code == status.HTTP_201_CREATED) and ('Location' in response):
if (response.status_code == status.HTTP_201_CREATED) \
and ('Location' in response):
response.status_code = status.HTTP_303_SEE_OTHER
response['Location'] = request.build_absolute_uri()
ret = ''
@ -818,7 +859,8 @@ class CoreJSONRenderer(BaseRenderer):
format = 'corejson'
def __init__(self):
assert coreapi, 'Using CoreJSONRenderer, but `coreapi` is not installed.'
assert coreapi, 'Using CoreJSONRenderer, but `coreapi` is not ' \
'installed.'
def render(self, data, media_type=None, renderer_context=None):
indent = bool(renderer_context.get('indent', 0))

View File

@ -5,35 +5,39 @@ from importlib import import_module
from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.utils import six
from django.utils.encoding import force_text, smart_text
from rest_framework import exceptions, renderers, serializers
from rest_framework import exceptions, renderers
from rest_framework.compat import (
RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse
)
from rest_framework.fields import (
BooleanField, DecimalField, Field, FileField, FloatField, HiddenField,
IntegerField, MultipleChoiceField, six
)
from rest_framework.relations import ManyRelatedField
from rest_framework.request import clone_request
from rest_framework.response import Response
from rest_framework.serializers import ListSerializer, Serializer
from rest_framework.settings import api_settings
from rest_framework.utils import formatting
from rest_framework.utils.field_mapping import ClassLookupDict
from rest_framework.utils.model_meta import _get_pk
from rest_framework.views import APIView
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
types_lookup = ClassLookupDict({
serializers.Field: 'string',
serializers.IntegerField: 'integer',
serializers.FloatField: 'number',
serializers.DecimalField: 'number',
serializers.BooleanField: 'boolean',
serializers.FileField: 'file',
serializers.MultipleChoiceField: 'array',
serializers.ManyRelatedField: 'array',
serializers.Serializer: 'object',
serializers.ListSerializer: 'array'
Field: 'string',
IntegerField: 'integer',
FloatField: 'number',
DecimalField: 'number',
BooleanField: 'boolean',
FileField: 'file',
MultipleChoiceField: 'array',
ManyRelatedField: 'array',
Serializer: 'object',
ListSerializer: 'array'
})
@ -104,6 +108,7 @@ class EndpointInspector(object):
"""
A class to determine the available API endpoints that a project exposes.
"""
def __init__(self, patterns=None, urlconf=None):
if patterns is None:
if urlconf is None:
@ -176,10 +181,8 @@ class EndpointInspector(object):
if hasattr(callback, 'actions'):
return [method.upper() for method in callback.actions.keys()]
return [
method for method in
callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD')
]
return [method for method in callback.cls().allowed_methods
if method not in ('OPTIONS', 'HEAD')]
class SchemaGenerator(object):
@ -193,8 +196,9 @@ class SchemaGenerator(object):
}
endpoint_inspector_cls = EndpointInspector
# Map the method names we use for viewset actions onto external schema names.
# These give us names that are more suitable for the external representation.
# Map the method names we use for viewset actions onto external schema
# names. These give us names that are more suitable for the external
# representation.
# Set by 'SCHEMA_COERCE_METHOD_NAMES'.
coerce_method_names = None
@ -223,7 +227,8 @@ class SchemaGenerator(object):
Generate a `coreapi.Document` representing the API schema.
"""
if self.endpoints is None:
inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
inspector = self.endpoint_inspector_cls(self.patterns,
self.urlconf)
self.endpoints = inspector.get_api_endpoints()
links = self.get_links(request)
@ -358,7 +363,8 @@ class SchemaGenerator(object):
fields += self.get_pagination_fields(path, method, view)
fields += self.get_filter_fields(path, method, view)
if fields and any([field.location in ('form', 'body') for field in fields]):
if fields and any(
[field.location in ('form', 'body') for field in fields]):
encoding = self.get_encoding(path, method, view)
else:
encoding = None
@ -438,7 +444,8 @@ class SchemaGenerator(object):
fields = []
for variable in uritemplate.variables(path):
field = coreapi.Field(name=variable, location='path', required=True)
field = coreapi.Field(name=variable, location='path',
required=True)
fields.append(field)
return fields
@ -456,7 +463,7 @@ class SchemaGenerator(object):
serializer = view.get_serializer()
if isinstance(serializer, serializers.ListSerializer):
if isinstance(serializer, ListSerializer):
return [
coreapi.Field(
name='data',
@ -466,16 +473,17 @@ class SchemaGenerator(object):
)
]
if not isinstance(serializer, serializers.Serializer):
if not isinstance(serializer, Serializer):
return []
fields = []
for field in serializer.fields.values():
if field.read_only or isinstance(field, serializers.HiddenField):
if field.read_only or isinstance(field, HiddenField):
continue
required = field.required and method != 'PATCH'
description = force_text(field.help_text) if field.help_text else ''
description = force_text(
field.help_text) if field.help_text else ''
field = coreapi.Field(
name=field.field_name,
location='form',
@ -517,27 +525,30 @@ class SchemaGenerator(object):
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
/users/{pk}/ ("users", "read"), ("users", "update"),
("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list
/users/{pk}/star/ ("users", "star") # custom viewset detail
/users/{pk}/groups/ ("users", "groups", "list"),
("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"),
("users", "groups", "update"),
("users", "groups", "delete")
"""
if hasattr(view, 'action'):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
# Views have no associated action, so we determine one from the
# method.
if is_list_view(subpath, method, view):
action = 'list'
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
]
named_path_components = [component for component
in subpath.strip('/').split('/')
if '{' not in component]
if is_custom_action(action):
# Custom action, eg "/users/{pk}/activate/", "/users/active/"
@ -562,8 +573,10 @@ def get_schema_view(title=None, url=None, renderer_classes=None):
"""
generator = SchemaGenerator(title=title, url=url)
if renderer_classes is None:
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
rclasses = [renderers.CoreJSONRenderer, renderers.BrowsableAPIRenderer]
if renderers.BrowsableAPIRenderer in \
api_settings.DEFAULT_RENDERER_CLASSES:
rclasses = [renderers.CoreJSONRenderer,
renderers.BrowsableAPIRenderer]
else:
rclasses = [renderers.CoreJSONRenderer]
else:

View File

@ -12,7 +12,9 @@ response content is handled by parsers and renderers.
"""
from __future__ import unicode_literals
import inspect
import traceback
from copy import deepcopy
from django.db import models
from django.db.models import DurationField as ModelDurationField
@ -23,6 +25,20 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import JSONField as ModelJSONField
from rest_framework.compat import postgres_fields, set_many, unicode_to_repr
from rest_framework.fields import (
BooleanField, CharField, ChoiceField, CreateOnlyDefault, DateField,
DateTimeField, DecimalField, DictField, DjangoValidationError,
DurationField, EmailField, ErrorDetail, Field, FileField, FilePathField,
FloatField, HiddenField, ImageField, IntegerField, IPAddressField,
JSONField, ListField, ModelField, NullBooleanField, OrderedDict,
ReadOnlyField, SkipField, SlugField, TimeField, URLField, UUIDField,
ValidationError, api_settings, empty, get_error_detail, html,
representation, set_value, six, timezone
)
from rest_framework.relations import (
HyperlinkedIdentityField, HyperlinkedRelatedField, ImproperlyConfigured,
PKOnlyObject, PrimaryKeyRelatedField, SlugRelatedField
)
from rest_framework.utils import model_meta
from rest_framework.utils.field_mapping import (
ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs,
@ -36,16 +52,6 @@ from rest_framework.validators import (
UniqueTogetherValidator
)
# Note: We do the following so that users of the framework can use this style:
#
# example_field = serializers.CharField(...)
#
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
from rest_framework.fields import * # NOQA # isort:skip
from rest_framework.relations import * # NOQA # isort:skip
# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
LIST_SERIALIZER_KWARGS = (
@ -125,12 +131,11 @@ class BaseSerializer(Field):
}
if allow_empty is not None:
list_kwargs['allow_empty'] = allow_empty
list_kwargs.update({
key: value for key, value in kwargs.items()
if key in LIST_SERIALIZER_KWARGS
})
list_kwargs.update({key: value for key, value in kwargs.items()
if key in LIST_SERIALIZER_KWARGS})
meta = getattr(cls, 'Meta', None)
list_serializer_class = getattr(meta, 'list_serializer_class', ListSerializer)
list_serializer_class = getattr(meta, 'list_serializer_class',
ListSerializer)
return list_serializer_class(*args, **list_kwargs)
def to_internal_value(self, data):
@ -164,17 +169,17 @@ class BaseSerializer(Field):
# Guard against incorrect use of `serializer.save(commit=False)`
assert 'commit' not in kwargs, (
"'commit' is not a valid keyword argument to the 'save()' method. "
"If you need to access data before committing to the database then "
"inspect 'serializer.validated_data' instead. "
"You can also pass additional keyword arguments to 'save()' if you "
"need to set extra attributes on the saved model instance. "
"If you need to access data before committing to the database then"
" inspect 'serializer.validated_data' instead. "
"You can also pass additional keyword arguments to 'save()' if you"
" need to set extra attributes on the saved model instance. "
"For example: 'serializer.save(owner=request.user)'.'"
)
assert not hasattr(self, '_data'), (
"You cannot call `.save()` after accessing `serializer.data`."
"If you need to access data before committing to the database then "
"inspect 'serializer.validated_data' instead. "
"If you need to access data before committing to the database then"
" inspect 'serializer.validated_data' instead. "
)
validated_data = dict(
@ -224,7 +229,8 @@ class BaseSerializer(Field):
@property
def data(self):
if hasattr(self, 'initial_data') and not hasattr(self, '_validated_data'):
if hasattr(self, 'initial_data') and not hasattr(self,
'_validated_data'):
msg = (
'When a serializer is passed a `data` keyword argument you '
'must call `.is_valid()` before attempting to access the '
@ -235,9 +241,12 @@ class BaseSerializer(Field):
raise AssertionError(msg)
if not hasattr(self, '_data'):
if self.instance is not None and not getattr(self, '_errors', None):
if self.instance is not None and not getattr(self, '_errors',
None):
self._data = self.to_representation(self.instance)
elif hasattr(self, '_validated_data') and not getattr(self, '_errors', None):
elif hasattr(self, '_validated_data') and not getattr(self,
'_errors',
None):
self._data = self.to_representation(self.validated_data)
else:
self._data = self.get_initial()
@ -253,7 +262,8 @@ class BaseSerializer(Field):
@property
def validated_data(self):
if not hasattr(self, '_validated_data'):
msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
msg = 'You must call `.is_valid()` before accessing ' \
'`.validated_data`.'
raise AssertionError(msg)
return self._validated_data
@ -277,9 +287,9 @@ class SerializerMetaclass(type):
if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1]._creation_counter)
# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
# If this class is subclassing another Serializer, add that
# Serializer's fields. Note that we loop over the bases in *reverse*.
# This is necessary in order to maintain the correct order of fields.
for base in reversed(bases):
if hasattr(base, '_declared_fields'):
fields = list(base._declared_fields.items()) + fields
@ -302,10 +312,8 @@ def as_serializer_error(exc):
if isinstance(detail, dict):
# If errors may be a dict we use the standard {key: list of values}.
# Here we ensure that all the values are *lists* of errors.
return {
key: value if isinstance(value, (list, dict)) else [value]
for key, value in detail.items()
}
return {key: value if isinstance(value, (list, dict)) else [value]
for key, value in detail.items()}
elif isinstance(detail, list):
# Errors raised as a list are non-field errors.
return {
@ -320,7 +328,8 @@ def as_serializer_error(exc):
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
default_error_messages = {
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
'invalid': _(
'Invalid data. Expected a dictionary, but got {datatype}.')
}
@property
@ -339,17 +348,13 @@ class Serializer(BaseSerializer):
@cached_property
def _writable_fields(self):
return [
field for field in self.fields.values()
if (not field.read_only) or (field.default is not empty)
]
return [field for field in self.fields.values()
if (not field.read_only) or (field.default is not empty)]
@cached_property
def _readable_fields(self):
return [
field for field in self.fields.values()
if not field.write_only
]
return [field for field in self.fields.values()
if not field.write_only]
def get_fields(self):
"""
@ -358,7 +363,7 @@ class Serializer(BaseSerializer):
# Every new serializer is created with a clone of the field instances.
# This allows users to dynamically modify the fields on a serializer
# instance without affecting every other serializer class.
return copy.deepcopy(self._declared_fields)
return deepcopy(self._declared_fields)
def get_validators(self):
"""
@ -371,24 +376,22 @@ class Serializer(BaseSerializer):
def get_initial(self):
if hasattr(self, 'initial_data'):
return OrderedDict([
(field_name, field.get_value(self.initial_data))
for field_name, field in self.fields.items()
if (field.get_value(self.initial_data) is not empty) and
not field.read_only
])
return OrderedDict([(field_name,
field.get_value(self.initial_data))
for field_name, field in self.fields.items()
if (field.get_value(self.initial_data)
is not empty) and not field.read_only])
return OrderedDict([
(field.field_name, field.get_initial())
for field in self.fields.values()
if not field.read_only
])
return OrderedDict([(field.field_name, field.get_initial())
for field in self.fields.values()
if not field.read_only])
def get_value(self, dictionary):
# We override the default field access in order to support
# nested HTML forms.
if html.is_html_input(dictionary):
return html.parse_html_dict(dictionary, prefix=self.field_name) or empty
return html.parse_html_dict(dictionary,
prefix=self.field_name) or empty
return dictionary.get(self.field_name, empty)
def run_validation(self, data=empty):
@ -405,7 +408,8 @@ class Serializer(BaseSerializer):
try:
self.run_validators(value)
value = self.validate(value)
assert value is not None, '.validate() should return the validated data'
assert value is not None, '.validate() should return the ' \
'validated data'
except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=as_serializer_error(exc))
@ -428,7 +432,8 @@ class Serializer(BaseSerializer):
fields = self._writable_fields
for field in fields:
validate_method = getattr(self, 'validate_' + field.field_name, None)
validate_method = getattr(self, 'validate_' + field.field_name,
None)
primitive_value = field.get_value(data)
try:
validated_value = field.run_validation(primitive_value)
@ -466,7 +471,8 @@ class Serializer(BaseSerializer):
#
# For related fields with `use_pk_only_optimization` we need to
# resolve the pk value.
check_for_none = attribute.pk if isinstance(attribute, PKOnlyObject) else attribute
check_for_none = attribute.pk if \
isinstance(attribute, PKOnlyObject) else attribute
if check_for_none is None:
ret[field.field_name] = None
else:
@ -523,15 +529,17 @@ class ListSerializer(BaseSerializer):
many = True
default_error_messages = {
'not_a_list': _('Expected a list of items but got type "{input_type}".'),
'not_a_list': _(
'Expected a list of items but got type "{input_type}".'),
'empty': _('This list may not be empty.')
}
def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
self.child = kwargs.pop('child', deepcopy(self.child))
self.allow_empty = kwargs.pop('allow_empty', True)
assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.'
assert not inspect.isclass(
self.child), '`child` has not been instantiated.'
super(ListSerializer, self).__init__(*args, **kwargs)
self.child.bind(field_name='', parent=self)
@ -564,7 +572,8 @@ class ListSerializer(BaseSerializer):
try:
self.run_validators(value)
value = self.validate(value)
assert value is not None, '.validate() should return the validated data'
assert value is not None, '.validate() should return the ' \
'validated data'
except (ValidationError, DjangoValidationError) as exc:
raise ValidationError(detail=as_serializer_error(exc))
@ -616,9 +625,7 @@ class ListSerializer(BaseSerializer):
# so, first get a queryset from the Manager if needed
iterable = data.all() if isinstance(data, models.Manager) else data
return [
self.child.to_representation(item) for item in iterable
]
return [self.child.to_representation(item) for item in iterable]
def validate(self, attrs):
return attrs
@ -633,9 +640,7 @@ class ListSerializer(BaseSerializer):
)
def create(self, validated_data):
return [
self.child.create(attrs) for attrs in validated_data
]
return [self.child.create(attrs) for attrs in validated_data]
def save(self, **kwargs):
"""
@ -644,17 +649,15 @@ class ListSerializer(BaseSerializer):
# Guard against incorrect use of `serializer.save(commit=False)`
assert 'commit' not in kwargs, (
"'commit' is not a valid keyword argument to the 'save()' method. "
"If you need to access data before committing to the database then "
"inspect 'serializer.validated_data' instead. "
"You can also pass additional keyword arguments to 'save()' if you "
"need to set extra attributes on the saved model instance. "
"If you need to access data before committing to the database then"
" inspect 'serializer.validated_data' instead. "
"You can also pass additional keyword arguments to 'save()' if you"
" need to set extra attributes on the saved model instance. "
"For example: 'serializer.save(owner=request.user)'.'"
)
validated_data = [
dict(list(attrs.items()) + list(kwargs.items()))
for attrs in self.validated_data
]
validated_data = [dict(list(attrs.items()) + list(kwargs.items()))
for attrs in self.validated_data]
if self.instance is not None:
self.instance = self.update(self.instance, validated_data)
@ -769,8 +772,8 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data):
isinstance(validated_data[key], (list, dict))
for key, field in serializer.fields.items()
), (
'The `.{method_name}()` method does not support writable dotted-source '
'fields by default.\nWrite an explicit `.{method_name}()` method for '
'The `.{method_name}()` method does not support writable dotted-source'
' fields by default.\nWrite an explicit `.{method_name}()` method for '
'serializer `{module}.{class_name}`, or set `read_only=True` on '
'dotted-source serializer fields.'.format(
method_name=method_name,
@ -944,7 +947,7 @@ class ModelSerializer(Serializer):
'Cannot use ModelSerializer with Abstract Models.'
)
declared_fields = copy.deepcopy(self._declared_fields)
declared_fields = deepcopy(self._declared_fields)
model = getattr(self.Meta, 'model')
depth = getattr(self.Meta, 'depth', 0)
@ -1003,7 +1006,8 @@ class ModelSerializer(Serializer):
fields = getattr(self.Meta, 'fields', None)
exclude = getattr(self.Meta, 'exclude', None)
if fields and fields != ALL_FIELDS and not isinstance(fields, (list, tuple)):
if fields and fields != ALL_FIELDS and not isinstance(fields,
(list, tuple)):
raise TypeError(
'The `fields` option must be a list or tuple or "__all__". '
'Got %s.' % type(fields).__name__
@ -1043,7 +1047,8 @@ class ModelSerializer(Serializer):
# a subset of fields.
required_field_names = set(declared_fields)
for cls in self.__class__.__bases__:
required_field_names -= set(getattr(cls, '_declared_fields', []))
required_field_names -= set(
getattr(cls, '_declared_fields', []))
for field_name in required_field_names:
assert field_name in fields, (
@ -1101,7 +1106,8 @@ class ModelSerializer(Serializer):
if not nested_depth:
return self.build_relational_field(field_name, relation_info)
else:
return self.build_nested_field(field_name, relation_info, nested_depth)
return self.build_nested_field(field_name, relation_info,
nested_depth)
elif hasattr(model_class, field_name):
return self.build_property_field(field_name, model_class)
@ -1126,7 +1132,8 @@ class ModelSerializer(Serializer):
field_class = self.serializer_choice_field
# Some model fields may introduce kwargs that would not be valid
# for the choice field. We need to strip these out.
# Eg. models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES)
# Eg. models.DecimalField(max_digits=3, decimal_places=1, \
# choices=DECIMAL_CHOICES)
valid_kwargs = set((
'read_only', 'write_only',
'required', 'default', 'initial', 'source',
@ -1144,11 +1151,13 @@ class ModelSerializer(Serializer):
# matched to the model field.
field_kwargs.pop('model_field', None)
if not issubclass(field_class, CharField) and not issubclass(field_class, ChoiceField):
if not issubclass(field_class, CharField) and not issubclass(
field_class, ChoiceField):
# `allow_blank` is only valid for textual fields.
field_kwargs.pop('allow_blank', None)
if postgres_fields and isinstance(model_field, postgres_fields.ArrayField):
if postgres_fields and isinstance(model_field,
postgres_fields.ArrayField):
# Populate the `child` argument on `ListField` instances generated
# for the PostgrSQL specfic `ArrayField`.
child_model_field = model_field.base_field
@ -1167,7 +1176,9 @@ class ModelSerializer(Serializer):
field_kwargs = get_relation_kwargs(field_name, relation_info)
to_field = field_kwargs.pop('to_field', None)
if to_field and not relation_info.reverse and not relation_info.related_model._meta.get_field(to_field).primary_key:
if to_field and not relation_info.reverse and not \
relation_info.related_model._meta.get_field(
to_field).primary_key:
field_kwargs['slug_field'] = to_field
field_class = self.serializer_related_to_field
@ -1181,6 +1192,7 @@ class ModelSerializer(Serializer):
"""
Create nested fields for forward and reverse relationships.
"""
class NestedSerializer(ModelSerializer):
class Meta:
model = relation_info.related_model
@ -1236,7 +1248,8 @@ class ModelSerializer(Serializer):
kwargs.pop('required')
if extra_kwargs.get('read_only', kwargs.get('read_only', False)):
extra_kwargs.pop('required', None) # Read only fields should always omit the 'required' argument.
# Read only fields should always omit the 'required' argument.
extra_kwargs.pop('required', None)
kwargs.update(extra_kwargs)
@ -1249,7 +1262,7 @@ class ModelSerializer(Serializer):
Return a dictionary mapping field names to a dictionary of
additional keyword arguments.
"""
extra_kwargs = copy.deepcopy(getattr(self.Meta, 'extra_kwargs', {}))
extra_kwargs = deepcopy(getattr(self.Meta, 'extra_kwargs', {}))
read_only_fields = getattr(self.Meta, 'read_only_fields', None)
if read_only_fields is not None:
@ -1265,7 +1278,8 @@ class ModelSerializer(Serializer):
return extra_kwargs
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
def get_uniqueness_extra_kwargs(self, field_names, declared_fields,
extra_kwargs):
"""
Return any additional field options that need to be included as a
result of uniqueness constraints on the model. This is returned as
@ -1288,7 +1302,8 @@ class ModelSerializer(Serializer):
for model_field in model_fields.values():
# Include each of the `unique_for_*` field names.
unique_constraint_names |= {model_field.unique_for_date, model_field.unique_for_month,
unique_constraint_names |= {model_field.unique_for_date,
model_field.unique_for_month,
model_field.unique_for_year}
unique_constraint_names -= {None}
@ -1302,13 +1317,15 @@ class ModelSerializer(Serializer):
# Now we have all the field names that have uniqueness constraints
# applied, we can add the extra 'required=...' or 'default=...'
# arguments that are appropriate to these fields, or add a `HiddenField` for it.
# arguments that are appropriate to these fields, or add a
# `HiddenField` for it.
hidden_fields = {}
uniqueness_extra_kwargs = {}
for unique_constraint_name in unique_constraint_names:
# Get the model field that is referred too.
unique_constraint_field = model._meta.get_field(unique_constraint_name)
unique_constraint_field = model._meta.get_field(
unique_constraint_name)
if getattr(unique_constraint_field, 'auto_now_add', None):
default = CreateOnlyDefault(timezone.now)
@ -1322,14 +1339,17 @@ class ModelSerializer(Serializer):
if unique_constraint_name in model_fields:
# The corresponding field is present in the serializer
if default is empty:
uniqueness_extra_kwargs[unique_constraint_name] = {'required': True}
uniqueness_extra_kwargs[unique_constraint_name] = {
'required': True}
else:
uniqueness_extra_kwargs[unique_constraint_name] = {'default': default}
uniqueness_extra_kwargs[unique_constraint_name] = {
'default': default}
elif default is not empty:
# The corresponding field is not present in the
# serializer. We have a default to use for it, so
# add in a hidden field that populates it.
hidden_fields[unique_constraint_name] = HiddenField(default=default)
hidden_fields[unique_constraint_name] = HiddenField(
default=default)
# Update `extra_kwargs` with any new options.
for key, value in uniqueness_extra_kwargs.items():
@ -1393,7 +1413,8 @@ class ModelSerializer(Serializer):
def get_unique_together_validators(self):
"""
Determine a default set of validators for any unique_together constraints.
Determine a default set of validators for any unique_together
constraints.
"""
model_class_inheritance_tree = (
[self.Meta.model] +
@ -1404,10 +1425,8 @@ class ModelSerializer(Serializer):
# which may map onto a model field. Any dotted field name lookups
# cannot map to a field, and must be a traversal, so we're not
# including those.
field_names = {
field.source for field in self._writable_fields
if (field.source != '*') and ('.' not in field.source)
}
field_names = {field.source for field in self._writable_fields
if (field.source != '*') and ('.' not in field.source)}
# Note that we make sure to check `unique_together` both on the
# base model class, but also on any parent classes.
@ -1469,14 +1488,17 @@ if hasattr(models, 'UUIDField'):
# IPAddressField is deprecated in Django
if hasattr(models, 'IPAddressField'):
ModelSerializer.serializer_field_mapping[models.IPAddressField] = IPAddressField
ModelSerializer.serializer_field_mapping[
models.IPAddressField] = IPAddressField
if postgres_fields:
class CharMappingField(DictField):
child = CharField(allow_blank=True)
ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField
ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField
ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = \
CharMappingField
ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = \
ListField
class HyperlinkedModelSerializer(ModelSerializer):
@ -1505,6 +1527,7 @@ class HyperlinkedModelSerializer(ModelSerializer):
"""
Create nested fields for forward and reverse relationships.
"""
class NestedSerializer(HyperlinkedModelSerializer):
class Meta:
model = relation_info.related_model

View File

@ -33,10 +33,8 @@ class TestManyPostView(TestCase):
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()]
self.view = ManyPostView.as_view()
def test_post_many_post_view(self):
@ -44,7 +42,8 @@ class TestManyPostView(TestCase):
POST request to a view that returns a list of objects should
still successfully return the browsable API with a rendered form.
Regression test for https://github.com/tomchristie/django-rest-framework/pull/3164
Regression test for
https://github.com/tomchristie/django-rest-framework/pull/3164
"""
data = {}
request = factory.post('/', data, format='json')

View File

@ -45,6 +45,7 @@ class NonAtomicAPIExceptionView(APIView):
BasicModel.objects.all()
raise Http404
urlpatterns = (
url(r'^$', NonAtomicAPIExceptionView.as_view()),
)
@ -89,7 +90,8 @@ class DBTransactionErrorTests(TestCase):
Transaction is eventually managed by outer-most transaction atomic
block. DRF do not try to interfere here.
We let django deal with the transaction when it will catch the Exception.
We let django deal with the transaction when it will catch the
Exception.
"""
request = factory.post('/')
with self.assertNumQueries(3):

File diff suppressed because it is too large Load Diff

View File

@ -85,10 +85,8 @@ class TestRootView(TestCase):
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()]
self.view = RootView.as_view()
def test_get_root_view(self):
@ -122,8 +120,10 @@ class TestRootView(TestCase):
request = factory.put('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": 'Method "PUT" not allowed.'})
self.assertEqual(response.status_code,
status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data,
{"detail": 'Method "PUT" not allowed.'})
def test_delete_root_view(self):
"""
@ -132,8 +132,10 @@ class TestRootView(TestCase):
request = factory.delete('/')
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": 'Method "DELETE" not allowed.'})
self.assertEqual(response.status_code,
status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data,
{"detail": 'Method "DELETE" not allowed.'})
def test_post_cannot_set_id(self):
"""
@ -156,7 +158,8 @@ class TestRootView(TestCase):
request = factory.post('/', data, HTTP_ACCEPT='text/html')
response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
self.assertIn(expected_error, response.rendered_content.decode('utf-8'))
self.assertIn(expected_error,
response.rendered_content.decode('utf-8'))
EXPECTED_QUERIES_FOR_PUT = 2
@ -171,10 +174,8 @@ class TestInstanceView(TestCase):
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
@ -196,8 +197,10 @@ class TestInstanceView(TestCase):
request = factory.post('/', data, format='json')
with self.assertNumQueries(0):
response = self.view(request).render()
self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data, {"detail": 'Method "POST" not allowed.'})
self.assertEqual(response.status_code,
status.HTTP_405_METHOD_NOT_ALLOWED)
self.assertEqual(response.data,
{"detail": 'Method "POST" not allowed.'})
def test_put_instance_view(self):
"""
@ -280,7 +283,8 @@ class TestInstanceView(TestCase):
"""
data = {'text': 'foo'}
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
request = factory.put('/{0}'.format(filtered_out_pk), data,
format='json')
response = self.view(request, pk=filtered_out_pk).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -303,7 +307,8 @@ class TestInstanceView(TestCase):
request = factory.put('/', data, HTTP_ACCEPT='text/html')
response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
self.assertIn(expected_error, response.rendered_content.decode('utf-8'))
self.assertIn(expected_error,
response.rendered_content.decode('utf-8'))
class TestFKInstanceView(TestCase):
@ -318,10 +323,8 @@ class TestFKInstanceView(TestCase):
ForeignKeySource(name='source_' + item, target=t).save()
self.objects = ForeignKeySource.objects
self.data = [
{'id': obj.id, 'name': obj.name}
for obj in self.objects.all()
]
self.data = [{'id': obj.id, 'name': obj.name}
for obj in self.objects.all()]
self.view = FKInstanceView.as_view()
@ -339,10 +342,8 @@ class TestOverriddenGetObject(TestCase):
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()]
class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
"""
@ -477,10 +478,8 @@ class TestFilterBackendAppliedToViews(TestCase):
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()]
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
@ -493,7 +492,8 @@ class TestFilterBackendAppliedToViews(TestCase):
self.assertEqual(len(response.data), 1)
self.assertEqual(response.data, [{'id': 1, 'text': 'foo'}])
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(
self):
"""
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
@ -507,17 +507,20 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
instance_view = InstanceView.as_view(
filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1')
response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {'detail': 'Not found.'})
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(
self):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
instance_view = InstanceView.as_view(
filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1')
response = instance_view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)

View File

@ -4,11 +4,14 @@ from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models
from django.test import TestCase
from rest_framework import (
exceptions, metadata, serializers, status, versioning, views
from rest_framework import exceptions, metadata, status, versioning, views
from rest_framework.fields import (
CharField, ChoiceField, IntegerField, ListField, NullBooleanField
)
from rest_framework.relations import PrimaryKeyRelatedField, RelatedField
from rest_framework.renderers import BrowsableAPIRenderer
from rest_framework.request import Request
from rest_framework.serializers import ModelSerializer, Serializer
from rest_framework.test import APIRequestFactory
from .models import BasicModel
@ -21,6 +24,7 @@ class TestMetadata:
"""
OPTIONS requests to views should return a valid 200 response.
"""
class ExampleView(views.APIView):
"""Example view."""
pass
@ -46,8 +50,9 @@ class TestMetadata:
def test_none_metadata(self):
"""
OPTIONS requests to views where `metadata_class = None` should raise
a MethodNotAllowed exception, which will result in an HTTP 405 response.
a MethodNotAllowed exception, which will result in an HTTP 405 response
"""
class ExampleView(views.APIView):
metadata_class = None
@ -61,27 +66,29 @@ class TestMetadata:
On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests.
"""
class NestedField(serializers.Serializer):
a = serializers.IntegerField()
b = serializers.IntegerField()
class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
integer_field = serializers.IntegerField(
class NestedField(Serializer):
a = IntegerField()
b = IntegerField()
class ExampleSerializer(Serializer):
choice_field = ChoiceField(['red', 'green', 'blue'])
integer_field = IntegerField(
min_value=1, max_value=1000
)
char_field = serializers.CharField(
char_field = CharField(
required=False, min_length=3, max_length=40
)
list_field = serializers.ListField(
child=serializers.ListField(
child=serializers.IntegerField()
list_field = ListField(
child=ListField(
child=IntegerField()
)
)
nested_field = NestedField()
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
@ -179,13 +186,15 @@ class TestMetadata:
If a user does not have global permissions on an action, then any
metadata associated with it should not be included in OPTION responses.
"""
class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False)
class ExampleSerializer(Serializer):
choice_field = ChoiceField(['red', 'green', 'blue'])
integer_field = IntegerField(max_value=10)
char_field = CharField(required=False)
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
@ -209,13 +218,15 @@ class TestMetadata:
If a user does not have object permissions on an action, then any
metadata associated with it should not be included in OPTION responses.
"""
class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False)
class ExampleSerializer(Serializer):
choice_field = ChoiceField(['red', 'green', 'blue'])
integer_field = IntegerField(max_value=10)
char_field = CharField(required=False)
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
@ -243,7 +254,7 @@ class TestMetadata:
def get_serializer(self):
assert hasattr(self.request, 'version')
return serializers.Serializer()
return Serializer()
view = ExampleView.as_view()
view(request=request)
@ -257,7 +268,7 @@ class TestMetadata:
def get_serializer(self):
assert hasattr(self.request, 'versioning_scheme')
return serializers.Serializer()
return Serializer()
scheme = versioning.QueryParameterVersioning
view = ExampleView.as_view(versioning_class=scheme)
@ -267,7 +278,7 @@ class TestMetadata:
class TestSimpleMetadataFieldInfo(TestCase):
def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField())
field_info = options.get_field_info(NullBooleanField())
self.assertEqual(field_info['type'], 'boolean')
def test_related_field_choices(self):
@ -275,7 +286,7 @@ class TestSimpleMetadataFieldInfo(TestCase):
BasicModel.objects.create()
with self.assertNumQueries(0):
field_info = options.get_field_info(
serializers.RelatedField(queryset=BasicModel.objects.all())
RelatedField(queryset=BasicModel.objects.all())
)
self.assertNotIn('choices', field_info)
@ -287,16 +298,18 @@ class TestModelSerializerMetadata(TestCase):
on the fields that may be supplied to PUT and POST requests. It should
not fail when a read_only PrimaryKeyRelatedField is present
"""
class Parent(models.Model):
integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)])
integer_field = models.IntegerField(
validators=[MinValueValidator(1), MaxValueValidator(1000)])
children = models.ManyToManyField('Child')
name = models.CharField(max_length=100, blank=True, null=True)
class Child(models.Model):
name = models.CharField(max_length=100)
class ExampleSerializer(serializers.ModelSerializer):
children = serializers.PrimaryKeyRelatedField(read_only=True, many=True)
class ExampleSerializer(ModelSerializer):
children = PrimaryKeyRelatedField(read_only=True, many=True)
class Meta:
model = Parent
@ -304,6 +317,7 @@ class TestModelSerializerMetadata(TestCase):
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass

View File

@ -17,7 +17,8 @@ class ChildModel(ParentModel):
class AssociatedModel(RESTFrameworkModel):
ref = models.OneToOneField(ParentModel, primary_key=True, on_delete=models.CASCADE)
ref = models.OneToOneField(ParentModel, primary_key=True,
on_delete=models.CASCADE)
name = models.CharField(max_length=100)
@ -36,7 +37,6 @@ class AssociatedModelSerializer(serializers.ModelSerializer):
# Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self):
"""
Assert that the parent pointer field is not included in the fields

View File

@ -5,9 +5,6 @@ from django.test import TestCase
from rest_framework import serializers
from tests.models import RESTFrameworkModel
# Models
from tests.test_multitable_inheritance import ChildModel
@ -26,7 +23,6 @@ class DerivedModelSerializer(serializers.ModelSerializer):
class ChildAssociatedModelSerializer(serializers.ModelSerializer):
class Meta:
model = ChildAssociatedModel
fields = ['id', 'child_name']
@ -34,7 +30,6 @@ class ChildAssociatedModelSerializer(serializers.ModelSerializer):
# Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self):
"""
Assert that the parent pointer field is not included in the fields

View File

@ -65,27 +65,33 @@ empty_list_view = EmptyListView.as_view()
def basic_auth_header(username, password):
credentials = ('%s:%s' % (username, password))
base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
base64_credentials = base64.b64encode(
credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
return 'Basic %s' % base64_credentials
class ModelPermissionsIntegrationTests(TestCase):
def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
User.objects.create_user('disallowed', 'disallowed@example.com',
'password')
user = User.objects.create_user('permitted', 'permitted@example.com',
'password')
set_many(user, 'user_permissions', [
Permission.objects.get(codename='add_basicmodel'),
Permission.objects.get(codename='change_basicmodel'),
Permission.objects.get(codename='delete_basicmodel')
])
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
user = User.objects.create_user('updateonly', 'updateonly@example.com',
'password')
set_many(user, 'user_permissions', [
Permission.objects.get(codename='change_basicmodel'),
])
self.permitted_credentials = basic_auth_header('permitted', 'password')
self.disallowed_credentials = basic_auth_header('disallowed', 'password')
self.updateonly_credentials = basic_auth_header('updateonly', 'password')
self.disallowed_credentials = basic_auth_header('disallowed',
'password')
self.updateonly_credentials = basic_auth_header('updateonly',
'password')
BasicModel(text='foo').save()
@ -120,7 +126,8 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_has_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.delete('/1',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
@ -137,7 +144,8 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_does_not_have_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
request = factory.delete(
'/1', HTTP_AUTHORIZATION=self.disallowed_credentials)
response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@ -196,7 +204,8 @@ class ModelPermissionsIntegrationTests(TestCase):
self.assertEqual(list(response.data['actions'].keys()), ['PUT'])
def test_empty_view_does_not_assert(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.permitted_credentials)
request = factory.get('/1',
HTTP_AUTHORIZATION=self.permitted_credentials)
response = empty_list_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_200_OK)
@ -237,6 +246,7 @@ class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions]
object_permissions_view = ObjectPermissionInstanceView.as_view()
@ -246,10 +256,12 @@ class ObjectPermissionListView(generics.ListAPIView):
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions]
object_permissions_list_view = ObjectPermissionListView.as_view()
class GetQuerysetObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
class GetQuerysetObjectPermissionInstanceView(generics.
RetrieveUpdateDestroyAPIView):
serializer_class = BasicPermSerializer
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions]
@ -258,7 +270,8 @@ class GetQuerysetObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIV
return BasicPermModel.objects.all()
get_queryset_object_permissions_view = GetQuerysetObjectPermissionInstanceView.as_view()
get_queryset_object_permissions_view = \
GetQuerysetObjectPermissionInstanceView.as_view()
@unittest.skipUnless(guardian, 'django-guardian not installed')
@ -266,16 +279,20 @@ class ObjectPermissionsIntegrationTests(TestCase):
"""
Integration tests for the object level permissions API.
"""
def setUp(self):
from guardian.shortcuts import assign_perm
# create users
create = User.objects.create_user
users = {
'fullaccess': create('fullaccess', 'fullaccess@example.com', 'password'),
'fullaccess': create('fullaccess', 'fullaccess@example.com',
'password'),
'readonly': create('readonly', 'readonly@example.com', 'password'),
'writeonly': create('writeonly', 'writeonly@example.com', 'password'),
'deleteonly': create('deleteonly', 'deleteonly@example.com', 'password'),
'writeonly': create('writeonly', 'writeonly@example.com',
'password'),
'deleteonly': create('deleteonly', 'deleteonly@example.com',
'password'),
}
# give everyone model level permissions, as we are not testing those
@ -310,16 +327,19 @@ class ObjectPermissionsIntegrationTests(TestCase):
self.credentials = {}
for user in users.values():
self.credentials[user.username] = basic_auth_header(user.username, 'password')
self.credentials[user.username] = basic_auth_header(user.username,
'password')
# Delete
def test_can_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['deleteonly'])
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials[
'deleteonly'])
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
def test_cannot_delete_permissions(self):
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
request = factory.delete('/1', HTTP_AUTHORIZATION=self.credentials[
'readonly'])
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
@ -351,12 +371,14 @@ class ObjectPermissionsIntegrationTests(TestCase):
# Read
def test_can_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
request = factory.get('/1',
HTTP_AUTHORIZATION=self.credentials['readonly'])
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_cannot_read_permissions(self):
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['writeonly'])
request = factory.get('/1',
HTTP_AUTHORIZATION=self.credentials['writeonly'])
response = object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
@ -365,21 +387,26 @@ class ObjectPermissionsIntegrationTests(TestCase):
same as ``test_can_read_permissions`` but with a view
that rely on ``.get_queryset()`` instead of ``.queryset``.
"""
request = factory.get('/1', HTTP_AUTHORIZATION=self.credentials['readonly'])
request = factory.get('/1',
HTTP_AUTHORIZATION=self.credentials['readonly'])
response = get_queryset_object_permissions_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Read list
def test_can_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['readonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
request = factory.get('/',
HTTP_AUTHORIZATION=self.credentials['readonly'])
object_permissions_list_view.cls.filter_backends = (
DjangoObjectPermissionsFilter,)
response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data[0].get('id'), 1)
def test_cannot_read_list_permissions(self):
request = factory.get('/', HTTP_AUTHORIZATION=self.credentials['writeonly'])
object_permissions_list_view.cls.filter_backends = (DjangoObjectPermissionsFilter,)
request = factory.get('/',
HTTP_AUTHORIZATION=self.credentials['writeonly'])
object_permissions_list_view.cls.filter_backends = (
DjangoObjectPermissionsFilter,)
response = object_permissions_list_view(request)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertListEqual(response.data, [])
@ -429,6 +456,7 @@ class DeniedObjectView(PermissionInstanceView):
class DeniedObjectViewWithDetail(PermissionInstanceView):
permission_classes = (BasicObjectPermWithDetail,)
denied_view = DeniedView.as_view()
denied_view_with_detail = DeniedViewWithDetail.as_view()
@ -441,31 +469,33 @@ denied_object_view_with_detail = DeniedObjectViewWithDetail.as_view()
class CustomPermissionsTests(TestCase):
def setUp(self):
BasicModel(text='foo').save()
User.objects.create_user('username', 'username@example.com', 'password')
User.objects.create_user('username', 'username@example.com',
'password')
credentials = basic_auth_header('username', 'password')
self.request = factory.get('/1', format='json', HTTP_AUTHORIZATION=credentials)
self.request = factory.get('/1', format='json',
HTTP_AUTHORIZATION=credentials)
self.custom_message = 'Custom: You cannot access this resource'
def test_permission_denied(self):
response = denied_view(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertNotEqual(detail, self.custom_message)
response = denied_view(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertNotEqual(detail, self.custom_message)
def test_permission_denied_with_custom_detail(self):
response = denied_view_with_detail(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(detail, self.custom_message)
response = denied_view_with_detail(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(detail, self.custom_message)
def test_permission_denied_for_object(self):
response = denied_object_view(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertNotEqual(detail, self.custom_message)
response = denied_object_view(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertNotEqual(detail, self.custom_message)
def test_permission_denied_for_object_with_custom_detail(self):
response = denied_object_view_with_detail(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(detail, self.custom_message)
response = denied_object_view_with_detail(self.request, pk=1)
detail = response.data.get('detail')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(detail, self.custom_message)

View File

@ -1,11 +1,14 @@
import uuid
import pytest
from django.core.exceptions import ImproperlyConfigured
from django.utils.datastructures import MultiValueDict
from rest_framework import serializers
from rest_framework.fields import empty
from rest_framework.fields import UUIDField, ValidationError, empty
from rest_framework.relations import (
Hyperlink, HyperlinkedIdentityField, HyperlinkedRelatedField,
ImproperlyConfigured, PrimaryKeyRelatedField, SlugRelatedField,
StringRelatedField
)
from rest_framework.test import APISimpleTestCase
from .utils import (
@ -16,7 +19,7 @@ from .utils import (
class TestStringRelatedField(APISimpleTestCase):
def setUp(self):
self.instance = MockObject(pk=1, name='foo')
self.field = serializers.StringRelatedField()
self.field = StringRelatedField()
def test_string_related_representation(self):
representation = self.field.to_representation(self.instance)
@ -31,20 +34,20 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
MockObject(pk=3, name='baz')
])
self.instance = self.queryset.items[2]
self.field = serializers.PrimaryKeyRelatedField(queryset=self.queryset)
self.field = PrimaryKeyRelatedField(queryset=self.queryset)
def test_pk_related_lookup_exists(self):
instance = self.field.to_internal_value(self.instance.pk)
assert instance is self.instance
def test_pk_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(4)
msg = excinfo.value.detail[0]
assert msg == 'Invalid pk "4" - object does not exist.'
def test_pk_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
assert msg == 'Incorrect type. Expected pk value, received BadType.'
@ -54,7 +57,7 @@ class TestPrimaryKeyRelatedField(APISimpleTestCase):
assert representation == self.instance.pk
def test_explicit_many_false(self):
field = serializers.PrimaryKeyRelatedField(queryset=self.queryset, many=False)
field = PrimaryKeyRelatedField(queryset=self.queryset, many=False)
instance = field.to_internal_value(self.instance.pk)
assert instance is self.instance
@ -67,9 +70,9 @@ class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase):
MockObject(pk=uuid.UUID(int=2), name='baz')
])
self.instance = self.queryset.items[2]
self.field = serializers.PrimaryKeyRelatedField(
self.field = PrimaryKeyRelatedField(
queryset=self.queryset,
pk_field=serializers.UUIDField(format='int')
pk_field=UUIDField(format='int')
)
def test_pk_related_lookup_exists(self):
@ -77,7 +80,7 @@ class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase):
assert instance is self.instance
def test_pk_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(4)
msg = excinfo.value.detail[0]
assert msg == 'Invalid pk "00000000-0000-0000-0000-000000000004" - object does not exist.'
@ -89,7 +92,7 @@ class TestProxiedPrimaryKeyRelatedField(APISimpleTestCase):
class TestHyperlinkedRelatedField(APISimpleTestCase):
def setUp(self):
self.field = serializers.HyperlinkedRelatedField(
self.field = HyperlinkedRelatedField(
view_name='example', read_only=True)
self.field.reverse = mock_reverse
self.field._context = {'request': True}
@ -102,7 +105,7 @@ class TestHyperlinkedRelatedField(APISimpleTestCase):
class TestHyperlinkedIdentityField(APISimpleTestCase):
def setUp(self):
self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example')
self.field = HyperlinkedIdentityField(view_name='example')
self.field.reverse = mock_reverse
self.field._context = {'request': True}
@ -135,14 +138,15 @@ class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase):
Tests for a hyperlinked identity field that has a `format` set,
which enforces that alternate formats are never linked too.
Eg. If your API includes some endpoints that accept both `.xml` and `.json`,
but other endpoints that only accept `.json`, we allow for hyperlinked
Eg. If your API includes some endpoints that accept both `.xml` and `.json`
but other endpoints only accept `.json`, we allow for hyperlinked
relationships that enforce only a single suffix type.
"""
def setUp(self):
self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json')
self.field = HyperlinkedIdentityField(view_name='example',
format='json')
self.field.reverse = mock_reverse
self.field._context = {'request': True}
@ -164,7 +168,7 @@ class TestSlugRelatedField(APISimpleTestCase):
MockObject(pk=3, name='baz')
])
self.instance = self.queryset.items[2]
self.field = serializers.SlugRelatedField(
self.field = SlugRelatedField(
slug_field='name', queryset=self.queryset
)
@ -173,13 +177,13 @@ class TestSlugRelatedField(APISimpleTestCase):
assert instance is self.instance
def test_slug_related_lookup_does_not_exist(self):
with pytest.raises(serializers.ValidationError) as excinfo:
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value('doesnotexist')
msg = excinfo.value.detail[0]
assert msg == 'Object with name=doesnotexist does not exist.'
def test_slug_related_lookup_invalid_type(self):
with pytest.raises(serializers.ValidationError) as excinfo:
with pytest.raises(ValidationError) as excinfo:
self.field.to_internal_value(BadType())
msg = excinfo.value.detail[0]
assert msg == 'Invalid value.'
@ -191,7 +195,7 @@ class TestSlugRelatedField(APISimpleTestCase):
def test_overriding_get_queryset(self):
qs = self.queryset
class NoQuerySetSlugRelatedField(serializers.SlugRelatedField):
class NoQuerySetSlugRelatedField(SlugRelatedField):
def get_queryset(self):
return qs
@ -202,7 +206,7 @@ class TestSlugRelatedField(APISimpleTestCase):
class TestManyRelatedField(APISimpleTestCase):
def setUp(self):
self.instance = MockObject(pk=1, name='foo')
self.field = serializers.StringRelatedField(many=True)
self.field = StringRelatedField(many=True)
self.field.field_name = 'foo'
def test_get_value_regular_dictionary_full(self):
@ -232,7 +236,7 @@ class TestManyRelatedField(APISimpleTestCase):
class TestHyperlink:
def setup(self):
self.default_hyperlink = serializers.Hyperlink('http://example.com', 'test')
self.default_hyperlink = Hyperlink('http://example.com', 'test')
def test_can_be_pickled(self):
import pickle

View File

@ -8,7 +8,8 @@ from django.db import models
from django.test import TestCase
from django.utils.encoding import python_2_unicode_compatible
from rest_framework import serializers
from rest_framework.relations import StringRelatedField
from rest_framework.serializers import ModelSerializer
@python_2_unicode_compatible
@ -51,7 +52,8 @@ class Note(models.Model):
class TestGenericRelations(TestCase):
def setUp(self):
self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
self.bookmark = Bookmark.objects.create(
url='https://www.djangoproject.com/')
Tag.objects.create(tagged_item=self.bookmark, tag='django')
Tag.objects.create(tagged_item=self.bookmark, tag='python')
self.note = Note.objects.create(text='Remember the milk')
@ -63,8 +65,8 @@ class TestGenericRelations(TestCase):
IE. A reverse generic relationship.
"""
class BookmarkSerializer(serializers.ModelSerializer):
tags = serializers.StringRelatedField(many=True)
class BookmarkSerializer(ModelSerializer):
tags = StringRelatedField(many=True)
class Meta:
model = Bookmark
@ -83,8 +85,8 @@ class TestGenericRelations(TestCase):
IE. A forward generic relationship.
"""
class TagSerializer(serializers.ModelSerializer):
tagged_item = serializers.StringRelatedField()
class TagSerializer(ModelSerializer):
tagged_item = StringRelatedField()
class Meta:
model = Tag

View File

@ -11,7 +11,8 @@ from tests.models import (
)
factory = APIRequestFactory()
request = factory.get('/') # Just to ensure we have a request in the serializer context
request = factory.get(
'/') # Just to ensure we have a request in the serializer context
def dummy_view(request, pk):
@ -20,13 +21,20 @@ def dummy_view(request, pk):
urlpatterns = [
url(r'^dummyurl/(?P<pk>[0-9]+)/$', dummy_view, name='dummy-url'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view, name='manytomanytarget-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeysource-detail'),
url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view, name='foreignkeytarget-detail'),
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableforeignkeysource-detail'),
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view, name='onetoonetarget-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view, name='nullableonetoonesource-detail'),
url(r'^manytomanysource/(?P<pk>[0-9]+)/$', dummy_view,
name='manytomanysource-detail'),
url(r'^manytomanytarget/(?P<pk>[0-9]+)/$', dummy_view,
name='manytomanytarget-detail'),
url(r'^foreignkeysource/(?P<pk>[0-9]+)/$', dummy_view,
name='foreignkeysource-detail'),
url(r'^foreignkeytarget/(?P<pk>[0-9]+)/$', dummy_view,
name='foreignkeytarget-detail'),
url(r'^nullableforeignkeysource/(?P<pk>[0-9]+)/$', dummy_view,
name='nullableforeignkeysource-detail'),
url(r'^onetoonetarget/(?P<pk>[0-9]+)/$', dummy_view,
name='onetoonetarget-detail'),
url(r'^nullableonetoonesource/(?P<pk>[0-9]+)/$', dummy_view,
name='nullableonetoonesource-detail'),
]
@ -57,7 +65,8 @@ class ForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
# Nullable ForeignKey
class NullableForeignKeySourceSerializer(serializers.HyperlinkedModelSerializer):
class NullableForeignKeySourceSerializer(serializers.
HyperlinkedModelSerializer):
class Meta:
model = NullableForeignKeySource
fields = ('url', 'name', 'target')
@ -84,83 +93,141 @@ class HyperlinkedManyToManyTests(TestCase):
def test_relative_hyperlinks(self):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': None})
serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': None})
expected = [
{'url': '/manytomanysource/1/', 'name': 'source-1', 'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': 'source-2', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': 'source-3', 'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/', '/manytomanytarget/3/']}
{'url': '/manytomanysource/1/', 'name': 'source-1',
'targets': ['/manytomanytarget/1/']},
{'url': '/manytomanysource/2/', 'name': 'source-2',
'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/']},
{'url': '/manytomanysource/3/', 'name': 'source-3',
'targets': ['/manytomanytarget/1/', '/manytomanytarget/2/',
'/manytomanytarget/3/']}
]
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)
def test_many_to_many_retrieve(self):
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
{'url': 'http://testserver/manytomanysource/1/',
'name': 'source-1',
'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/',
'name': 'source-2',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/',
'name': 'source-3',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']}
]
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)
def test_many_to_many_retrieve_prefetch_related(self):
queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': request})
with self.assertNumQueries(2):
serializer.data
def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
serializer = ManyToManyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
{'url': 'http://testserver/manytomanytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/manytomanysource/1/',
'http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/manytomanysource/3/']}
]
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)
def test_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
data = {'url': 'http://testserver/manytomanysource/1/',
'name': 'source-1',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']}
instance = ManyToManySource.objects.get(pk=1)
serializer = ManyToManySourceSerializer(instance, data=data, context={'request': request})
serializer = ManyToManySourceSerializer(instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
{'url': 'http://testserver/manytomanysource/1/',
'name': 'source-1',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/2/',
'name': 'source-2',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/',
'name': 'source-3',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']}
]
self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']}
data = {'url': 'http://testserver/manytomanytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/manytomanysource/1/']}
instance = ManyToManyTarget.objects.get(pk=1)
serializer = ManyToManyTargetSerializer(instance, data=data, context={'request': request})
serializer = ManyToManyTargetSerializer(instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
self.assertEqual(serializer.data, data)
# Ensure target 1 is updated, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
serializer = ManyToManyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
{'url': 'http://testserver/manytomanytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/manytomanysource/1/']},
{'url': 'http://testserver/manytomanytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/manytomanysource/3/']}
]
self.assertEqual(serializer.data, expected)
def test_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
serializer = ManyToManySourceSerializer(data=data, context={'request': request})
data = {'url': 'http://testserver/manytomanysource/4/',
'name': 'source-4',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/3/']}
serializer = ManyToManySourceSerializer(data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
@ -168,18 +235,35 @@ class HyperlinkedManyToManyTests(TestCase):
# Ensure source 4 is added, and everything else is as expected
queryset = ManyToManySource.objects.all()
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
serializer = ManyToManySourceSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/4/', 'name': 'source-4', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/3/']}
{'url': 'http://testserver/manytomanysource/1/',
'name': 'source-1',
'targets': ['http://testserver/manytomanytarget/1/']},
{'url': 'http://testserver/manytomanysource/2/',
'name': 'source-2',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/',
'name': 'source-3',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/2/',
'http://testserver/manytomanytarget/3/']},
{'url': 'http://testserver/manytomanysource/4/',
'name': 'source-4',
'targets': ['http://testserver/manytomanytarget/1/',
'http://testserver/manytomanytarget/3/']}
]
self.assertEqual(serializer.data, expected)
def test_reverse_many_to_many_create(self):
data = {'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
serializer = ManyToManyTargetSerializer(data=data, context={'request': request})
data = {'url': 'http://testserver/manytomanytarget/4/',
'name': 'target-4',
'sources': ['http://testserver/manytomanysource/1/',
'http://testserver/manytomanysource/3/']}
serializer = ManyToManyTargetSerializer(data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
@ -187,12 +271,25 @@ class HyperlinkedManyToManyTests(TestCase):
# Ensure target 4 is added, and everything else is as expected
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True, context={'request': request})
serializer = ManyToManyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/manytomanytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/4/', 'name': 'target-4', 'sources': ['http://testserver/manytomanysource/1/', 'http://testserver/manytomanysource/3/']}
{'url': 'http://testserver/manytomanytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/manytomanysource/1/',
'http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/manytomanysource/2/',
'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/4/',
'name': 'target-4',
'sources': ['http://testserver/manytomanysource/1/',
'http://testserver/manytomanysource/3/']}
]
self.assertEqual(serializer.data, expected)
@ -210,62 +307,99 @@ class HyperlinkedForeignKeyTests(TestCase):
def test_foreign_key_retrieve(self):
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
serializer = ForeignKeySourceSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
{'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/',
'name': 'source-3',
'target': 'http://testserver/foreignkeytarget/1/'}
]
with self.assertNumQueries(1):
self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
serializer = ForeignKeyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/2/',
'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2', 'sources': []},
]
with self.assertNumQueries(3):
self.assertEqual(serializer.data, expected)
def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
data = {'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1',
'target': 'http://testserver/foreignkeytarget/2/'}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
serializer = ForeignKeySourceSerializer(instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
serializer = ForeignKeySourceSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
{'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1',
'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/',
'name': 'source-3',
'target': 'http://testserver/foreignkeytarget/1/'}
]
self.assertEqual(serializer.data, expected)
def test_foreign_key_update_incorrect_type(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 2}
data = {'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
serializer = ForeignKeySourceSerializer(instance, data=data,
context={'request': request})
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected URL string, received int.']})
self.assertEqual(serializer.errors, {
'target': ['Incorrect type. Expected URL string, received int.']})
def test_reverse_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
data = {'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/3/']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data, context={'request': request})
serializer = ForeignKeyTargetSerializer(instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
# We shouldn't have saved anything to the db yet since save
# hasn't been called.
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
new_serializer = ForeignKeyTargetSerializer(queryset, many=True,
context={
'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/2/',
'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2', 'sources': []},
]
self.assertEqual(new_serializer.data, expected)
@ -274,16 +408,25 @@ class HyperlinkedForeignKeyTests(TestCase):
# Ensure target 2 is update, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
serializer = ForeignKeyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/3/']},
]
self.assertEqual(serializer.data, expected)
def test_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data, context={'request': request})
data = {'url': 'http://testserver/foreignkeysource/4/',
'name': 'source-4',
'target': 'http://testserver/foreignkeytarget/2/'}
serializer = ForeignKeySourceSerializer(data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
@ -291,18 +434,31 @@ class HyperlinkedForeignKeyTests(TestCase):
# Ensure source 1 is updated, and everything else is as expected
queryset = ForeignKeySource.objects.all()
serializer = ForeignKeySourceSerializer(queryset, many=True, context={'request': request})
serializer = ForeignKeySourceSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/', 'name': 'source-4', 'target': 'http://testserver/foreignkeytarget/2/'},
{'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/',
'name': 'source-3',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/4/',
'name': 'source-4',
'target': 'http://testserver/foreignkeytarget/2/'},
]
self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self):
data = {'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data, context={'request': request})
data = {'url': 'http://testserver/foreignkeytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/3/']}
serializer = ForeignKeyTargetSerializer(data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
@ -310,20 +466,30 @@ class HyperlinkedForeignKeyTests(TestCase):
# Ensure target 4 is added, and everything else is as expected
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True, context={'request': request})
serializer = ForeignKeyTargetSerializer(queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/1/',
'name': 'target-1',
'sources': ['http://testserver/foreignkeysource/2/']},
{'url': 'http://testserver/foreignkeytarget/2/',
'name': 'target-2', 'sources': []},
{'url': 'http://testserver/foreignkeytarget/3/',
'name': 'target-3',
'sources': ['http://testserver/foreignkeysource/1/',
'http://testserver/foreignkeysource/3/']},
]
self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_invalid_null(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': None}
data = {'url': 'http://testserver/foreignkeysource/1/',
'name': 'source-1', 'target': None}
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data, context={'request': request})
serializer = ForeignKeySourceSerializer(instance, data=data,
context={'request': request})
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field may not be null.']})
self.assertEqual(serializer.errors,
{'target': ['This field may not be null.']})
@override_settings(ROOT_URLCONF='tests.test_relations_hyperlink')
@ -334,22 +500,32 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source = NullableForeignKeySource(name='source-%d' % idx,
target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
serializer = NullableForeignKeySourceSerializer(
queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
]
self.assertEqual(serializer.data, expected)
def test_foreign_key_create_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
data = {'url': 'http://testserver/nullableforeignkeysource/4/',
'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={
'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, data)
@ -357,12 +533,20 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
serializer = NullableForeignKeySourceSerializer(
queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
{'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/',
'name': 'source-4', 'target': None}
]
self.assertEqual(serializer.data, expected)
@ -371,9 +555,13 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={'request': request})
data = {'url': 'http://testserver/nullableforeignkeysource/4/',
'name': 'source-4', 'target': ''}
expected_data = {
'url': 'http://testserver/nullableforeignkeysource/4/',
'name': 'source-4', 'target': None}
serializer = NullableForeignKeySourceSerializer(data=data, context={
'request': request})
self.assertTrue(serializer.is_valid())
obj = serializer.save()
self.assertEqual(serializer.data, expected_data)
@ -381,30 +569,47 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
# Ensure source 4 is created, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
serializer = NullableForeignKeySourceSerializer(
queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/', 'name': 'source-4', 'target': None}
{'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/4/',
'name': 'source-4', 'target': None}
]
self.assertEqual(serializer.data, expected)
def test_foreign_key_update_with_valid_null(self):
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
data = {'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
serializer = NullableForeignKeySourceSerializer(
instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
self.assertEqual(serializer.data, data)
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
serializer = NullableForeignKeySourceSerializer(
queryset, many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
]
self.assertEqual(serializer.data, expected)
@ -413,21 +618,33 @@ class HyperlinkedNullableForeignKeyTests(TestCase):
The emptystring should be interpreted as null in the context
of relationships.
"""
data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': ''}
expected_data = {'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None}
data = {'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1', 'target': ''}
expected_data = {
'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1', 'target': None}
instance = NullableForeignKeySource.objects.get(pk=1)
serializer = NullableForeignKeySourceSerializer(instance, data=data, context={'request': request})
serializer = NullableForeignKeySourceSerializer(
instance, data=data,
context={'request': request})
self.assertTrue(serializer.is_valid())
serializer.save()
self.assertEqual(serializer.data, expected_data)
# Ensure source 1 is updated, and everything else is as expected
queryset = NullableForeignKeySource.objects.all()
serializer = NullableForeignKeySourceSerializer(queryset, many=True, context={'request': request})
serializer = NullableForeignKeySourceSerializer(
queryset,
many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/nullableforeignkeysource/1/', 'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/', 'name': 'source-3', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/1/',
'name': 'source-1', 'target': None},
{'url': 'http://testserver/nullableforeignkeysource/2/',
'name': 'source-2',
'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/nullableforeignkeysource/3/',
'name': 'source-3', 'target': None},
]
self.assertEqual(serializer.data, expected)
@ -444,9 +661,14 @@ class HyperlinkedNullableOneToOneTests(TestCase):
def test_reverse_foreign_key_retrieve_with_null(self):
queryset = OneToOneTarget.objects.all()
serializer = NullableOneToOneTargetSerializer(queryset, many=True, context={'request': request})
serializer = NullableOneToOneTargetSerializer(
queryset,
many=True,
context={'request': request})
expected = [
{'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1', 'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2', 'nullable_source': None},
{'url': 'http://testserver/onetoonetarget/1/', 'name': 'target-1',
'nullable_source': 'http://testserver/nullableonetoonesource/1/'},
{'url': 'http://testserver/onetoonetarget/2/', 'name': 'target-2',
'nullable_source': None},
]
self.assertEqual(serializer.data, expected)

View File

@ -248,7 +248,8 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
self.assertEqual(serializer.errors, {'target': [
'Incorrect type. Expected pk value, received %s.' % six.text_type.__name__]})
def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': [1, 3]}
@ -319,7 +320,8 @@ class PKForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field may not be null.']})
self.assertEqual(serializer.errors,
{'target': ['This field may not be null.']})
def test_foreign_key_with_unsaved(self):
source = ForeignKeySource(name='source-unsaved')
@ -345,9 +347,11 @@ class PKForeignKeyTests(TestCase):
Let's say we wanted to fill the non-nullable model field inside
Model.save(), we would make it empty and not required.
"""
class ModelSerializer(ForeignKeySourceSerializer):
class Meta(ForeignKeySourceSerializer.Meta):
extra_kwargs = {'target': {'required': False}}
serializer = ModelSerializer(data={'name': 'test'})
serializer.is_valid(raise_exception=True)
self.assertNotIn('target', serializer.validated_data)
@ -360,7 +364,8 @@ class PKNullableForeignKeyTests(TestCase):
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source = NullableForeignKeySource(name='source-%d' % idx,
target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):

View File

@ -73,7 +73,8 @@ class SlugForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all()
serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 1, 'name': 'target-1',
'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
self.assertEqual(serializer.data, expected)
@ -107,10 +108,12 @@ class SlugForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['Object with name=123 does not exist.']})
self.assertEqual(serializer.errors,
{'target': ['Object with name=123 does not exist.']})
def test_reverse_foreign_key_update(self):
data = {'id': 2, 'name': 'target-2', 'sources': ['source-1', 'source-3']}
data = {'id': 2, 'name': 'target-2',
'sources': ['source-1', 'source-3']}
instance = ForeignKeyTarget.objects.get(pk=2)
serializer = ForeignKeyTargetSerializer(instance, data=data)
self.assertTrue(serializer.is_valid())
@ -119,7 +122,8 @@ class SlugForeignKeyTests(TestCase):
queryset = ForeignKeyTarget.objects.all()
new_serializer = ForeignKeyTargetSerializer(queryset, many=True)
expected = [
{'id': 1, 'name': 'target-1', 'sources': ['source-1', 'source-2', 'source-3']},
{'id': 1, 'name': 'target-1',
'sources': ['source-1', 'source-2', 'source-3']},
{'id': 2, 'name': 'target-2', 'sources': []},
]
self.assertEqual(new_serializer.data, expected)
@ -157,7 +161,8 @@ class SlugForeignKeyTests(TestCase):
self.assertEqual(serializer.data, expected)
def test_reverse_foreign_key_create(self):
data = {'id': 3, 'name': 'target-3', 'sources': ['source-1', 'source-3']}
data = {'id': 3, 'name': 'target-3',
'sources': ['source-1', 'source-3']}
serializer = ForeignKeyTargetSerializer(data=data)
self.assertTrue(serializer.is_valid())
obj = serializer.save()
@ -179,7 +184,8 @@ class SlugForeignKeyTests(TestCase):
instance = ForeignKeySource.objects.get(pk=1)
serializer = ForeignKeySourceSerializer(instance, data=data)
self.assertFalse(serializer.is_valid())
self.assertEqual(serializer.errors, {'target': ['This field may not be null.']})
self.assertEqual(serializer.errors,
{'target': ['This field may not be null.']})
class SlugNullableForeignKeyTests(TestCase):
@ -189,7 +195,8 @@ class SlugNullableForeignKeyTests(TestCase):
for idx in range(1, 4):
if idx == 3:
target = None
source = NullableForeignKeySource(name='source-%d' % idx, target=target)
source = NullableForeignKeySource(name='source-%d' % idx,
target=target)
source.save()
def test_foreign_key_retrieve_with_null(self):

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import json
import re
from collections import MutableMapping, OrderedDict
from collections import MutableMapping
from django.conf.urls import include, url
from django.core.cache import cache
@ -13,11 +13,15 @@ from django.utils import six
from django.utils.safestring import SafeText
from django.utils.translation import ugettext_lazy as _
from rest_framework import permissions, serializers, status
from rest_framework import permissions, status
from rest_framework.fields import (
CharField, ChoiceField, HiddenField, MultipleChoiceField, OrderedDict
)
from rest_framework.renderers import (
BaseRenderer, BrowsableAPIRenderer, HTMLFormRenderer, JSONRenderer
)
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
@ -92,7 +96,7 @@ class EmptyGETView(APIView):
class HTMLView(APIView):
renderer_classes = (BrowsableAPIRenderer, )
renderer_classes = (BrowsableAPIRenderer,)
def get(self, request, **kwargs):
return Response('text')
@ -104,11 +108,14 @@ class HTMLView1(APIView):
def get(self, request, **kwargs):
return Response('text')
urlpatterns = [
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^.*\.(?P<format>.+)$',
MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB])),
url(r'^cache$', MockGETView.as_view()),
url(r'^parseerror$', MockPOSTView.as_view(renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
url(r'^parseerror$', MockPOSTView.as_view(
renderer_classes=[JSONRenderer, BrowsableAPIRenderer])),
url(r'^html$', HTMLView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
url(r'^empty$', EmptyGETView.as_view()),
@ -153,10 +160,13 @@ class RendererEndToEndTests(TestCase):
"""
End-to-end testing of renderers using an RendererMixin on a generic view.
"""
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
"""If the Accept header is not set the default renderer should
serialize the response."""
resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@ -164,34 +174,42 @@ class RendererEndToEndTests(TestCase):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
"""If the Accept header is set to */* the default renderer should
serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for the default renderer)"""
"""If the Accept header is set the specified renderer should serialize
the response. (In this case we check that works for the default
renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
(In this case we check that works for a non-default renderer)"""
"""If the Accept header is set the specified renderer should serialize
the response. (In this case we check that works for a non-default
renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_unsatisfiable_accept_header_on_request_returns_406_status(self):
"""If the Accept header is unsatisfiable we should return a 406 Not Acceptable response."""
"""If the Accept header is unsatisfiable we should return a
406 Not Acceptable response."""
resp = self.client.get('/', HTTP_ACCEPT='foo/bar')
self.assertEqual(resp.status_code, status.HTTP_406_NOT_ACCEPTABLE)
@ -203,34 +221,41 @@ class RendererEndToEndTests(TestCase):
RendererB.format
)
resp = self.client.get('/' + param)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
"""If a 'format' keyword arg is specified, the renderer with the
matching format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
def test_specified_renderer_is_used_on_format_query_with_matching_accept(
self):
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
the renderer with the matching format attribute should serialize the
response."""
param = '?%s=%s' % (
api_settings.URL_FORMAT_OVERRIDE,
RendererB.format
)
resp = self.client.get('/' + param,
HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_parse_error_renderers_browsable_api(self):
"""Invalid data should still render the browsable API correctly."""
resp = self.client.post('/parseerror', data='foobar', content_type='application/json', HTTP_ACCEPT='text/html')
resp = self.client.post('/parseerror', data='foobar',
content_type='application/json',
HTTP_ACCEPT='text/html')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertEqual(resp.status_code, status.HTTP_400_BAD_REQUEST)
@ -358,18 +383,22 @@ class JSONRendererTests(TestCase):
obj = {'foo': ['bar', 'baz']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json; indent=2')
self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')), _indented_repr)
self.assertEqual(strip_trailing_whitespace(content.decode('utf-8')),
_indented_repr)
class UnicodeJSONRendererTests(TestCase):
"""
Tests specific for the Unicode JSON Renderer
"""
def test_proper_encoding(self):
obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","España"]}'.encode('utf-8'))
self.assertEqual(content,
'{"countries":["United Kingdom","France","España"]}'.
encode('utf-8'))
def test_u2028_u2029(self):
# The \u2028 and \u2029 characters should be escaped,
@ -378,20 +407,26 @@ class UnicodeJSONRendererTests(TestCase):
obj = {'should_escape': '\u2028\u2029'}
renderer = JSONRenderer()
content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"should_escape":"\\u2028\\u2029"}'.encode('utf-8'))
self.assertEqual(content,
'{"should_escape":"\\u2028\\u2029"}'.encode('utf-8'))
class AsciiJSONRendererTests(TestCase):
"""
Tests specific for the Unicode JSON Renderer
"""
def test_proper_encoding(self):
class AsciiJSONRenderer(JSONRenderer):
ensure_ascii = True
obj = {'countries': ['United Kingdom', 'France', 'España']}
renderer = AsciiJSONRenderer()
content = renderer.render(obj, 'application/json')
self.assertEqual(content, '{"countries":["United Kingdom","France","Espa\\u00f1a"]}'.encode('utf-8'))
self.assertEqual(content,
'{"countries":["United Kingdom","France",'
'"Espa\\u00f1a"]}'.encode(
'utf-8'))
# Tests for caching issue, #346
@ -400,6 +435,7 @@ class CacheRenderTest(TestCase):
"""
Tests specific to caching responses
"""
def test_head_caching(self):
"""
Test caching of HEAD requests
@ -447,8 +483,8 @@ class TestJSONIndentationStyles:
class TestHiddenFieldHTMLFormRenderer(TestCase):
def test_hidden_field_rendering(self):
class TestSerializer(serializers.Serializer):
published = serializers.HiddenField(default=True)
class TestSerializer(Serializer):
published = HiddenField(default=True)
serializer = TestSerializer(data={})
serializer.is_valid()
@ -460,8 +496,8 @@ class TestHiddenFieldHTMLFormRenderer(TestCase):
class TestHTMLFormRenderer(TestCase):
def setUp(self):
class TestSerializer(serializers.Serializer):
test_field = serializers.CharField()
class TestSerializer(Serializer):
test_field = CharField()
self.renderer = HTMLFormRenderer()
self.serializer = TestSerializer(data={})
@ -491,9 +527,9 @@ class TestChoiceFieldHTMLFormRenderer(TestCase):
def setUp(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
test_field = serializers.ChoiceField(choices=choices,
initial=2)
class TestSerializer(Serializer):
test_field = ChoiceField(choices=choices,
initial=2)
self.TestSerializer = TestSerializer
self.renderer = HTMLFormRenderer()
@ -535,8 +571,8 @@ class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'),
('}', 'OptionBrace'))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
class TestSerializer(Serializer):
test_field = MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid()
@ -554,8 +590,8 @@ class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
def test_render_selected_option_with_integer_option_ids(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
class TestSerializer(Serializer):
test_field = MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid()

View File

@ -32,6 +32,7 @@ class MockJsonRenderer(BaseRenderer):
class MockTextMediaRenderer(BaseRenderer):
media_type = 'text/html'
DUMMYSTATUS = status.HTTP_200_OK
DUMMYCONTENT = 'dummycontent'
@ -77,7 +78,8 @@ class MockViewSettingContentType(APIView):
renderer_classes = (RendererA, RendererB, RendererC)
def get(self, request, **kwargs):
return Response(DUMMYCONTENT, status=DUMMYSTATUS, content_type='setbyview')
return Response(DUMMYCONTENT, status=DUMMYSTATUS,
content_type='setbyview')
class JSONView(APIView):
@ -89,7 +91,7 @@ class JSONView(APIView):
class HTMLView(APIView):
renderer_classes = (BrowsableAPIRenderer, )
renderer_classes = (BrowsableAPIRenderer,)
def get(self, request, **kwargs):
return Response('text')
@ -117,17 +119,20 @@ class HTMLNewModelView(generics.ListCreateAPIView):
new_model_viewset_router = routers.DefaultRouter()
new_model_viewset_router.register(r'', HTMLNewModelViewSet)
urlpatterns = [
url(r'^setbyview$', MockViewSettingContentType.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^.*\.(?P<format>.+)$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^$', MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^setbyview$', MockViewSettingContentType.as_view(
renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^.*\.(?P<format>.+)$',
MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^$',
MockView.as_view(renderer_classes=[RendererA, RendererB, RendererC])),
url(r'^html$', HTMLView.as_view()),
url(r'^json$', JSONView.as_view()),
url(r'^html1$', HTMLView1.as_view()),
url(r'^html_new_model$', HTMLNewModelView.as_view()),
url(r'^html_new_model_viewset', include(new_model_viewset_router.urls)),
url(r'^restframework', include('rest_framework.urls', namespace='rest_framework'))
url(r'^restframework',
include('rest_framework.urls', namespace='rest_framework'))
]
@ -137,10 +142,13 @@ class RendererIntegrationTests(TestCase):
"""
End-to-end testing of renderers using an ResponseMixin on a generic view.
"""
def test_default_renderer_serializes_content(self):
"""If the Accept header is not set the default renderer should serialize the response."""
"""If the Accept header is not set the default renderer should
serialize the response."""
resp = self.client.get('/')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@ -148,29 +156,36 @@ class RendererIntegrationTests(TestCase):
"""No response must be included in HEAD requests."""
resp = self.client.head('/')
self.assertEqual(resp.status_code, DUMMYSTATUS)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, six.b(''))
def test_default_renderer_serializes_content_on_accept_any(self):
"""If the Accept header is set to */* the default renderer should serialize the response."""
"""If the Accept header is set to */* the default renderer should
serialize the response."""
resp = self.client.get('/', HTTP_ACCEPT='*/*')
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
"""If the Accept header is set the specified renderer should serialize
the response.
(In this case we check that works for the default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererA.media_type)
self.assertEqual(resp['Content-Type'], RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererA.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_A_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_non_default_case(self):
"""If the Accept header is set the specified renderer should serialize the response.
"""If the Accept header is set the specified renderer should serialize
the response.
(In this case we check that works for a non-default renderer)"""
resp = self.client.get('/', HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@ -178,24 +193,29 @@ class RendererIntegrationTests(TestCase):
"""If a 'format' query is specified, the renderer with the matching
format attribute should serialize the response."""
resp = self.client.get('/?format=%s' % RendererB.format)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_serializes_content_on_format_kwargs(self):
"""If a 'format' keyword arg is specified, the renderer with the matching
format attribute should serialize the response."""
"""If a 'format' keyword arg is specified, the renderer with the
matching format attribute should serialize the response."""
resp = self.client.get('/something.formatb')
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
def test_specified_renderer_is_used_on_format_query_with_matching_accept(self):
def test_specified_renderer_is_used_on_format_query_with_matching_accept(
self):
"""If both a 'format' query and a matching Accept header specified,
the renderer with the matching format attribute should serialize the response."""
the renderer with the matching format attribute should serialize the
response."""
resp = self.client.get('/?format=%s' % RendererB.format,
HTTP_ACCEPT=RendererB.media_type)
self.assertEqual(resp['Content-Type'], RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp['Content-Type'],
RendererB.media_type + '; charset=utf-8')
self.assertEqual(resp.content, RENDERER_B_SERIALIZER(DUMMYCONTENT))
self.assertEqual(resp.status_code, DUMMYSTATUS)
@ -203,12 +223,14 @@ class RendererIntegrationTests(TestCase):
@override_settings(ROOT_URLCONF='tests.test_response')
class UnsupportedMediaTypeTests(TestCase):
def test_should_allow_posting_json(self):
response = self.client.post('/json', data='{"test": 123}', content_type='application/json')
response = self.client.post('/json', data='{"test": 123}',
content_type='application/json')
self.assertEqual(response.status_code, 200)
def test_should_not_allow_posting_xml(self):
response = self.client.post('/json', data='<test>123</test>', content_type='application/xml')
response = self.client.post('/json', data='<test>123</test>',
content_type='application/xml')
self.assertEqual(response.status_code, 415)
@ -223,6 +245,7 @@ class Issue122Tests(TestCase):
"""
Tests that covers #122.
"""
def test_only_html_renderer(self):
"""
Test if no infinite recursion occurs.
@ -241,6 +264,7 @@ class Issue467Tests(TestCase):
"""
Tests for #467
"""
def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
@ -253,6 +277,7 @@ class Issue807Tests(TestCase):
"""
Covers #807
"""
def test_does_not_append_charset_by_default(self):
"""
Renderers don't include a charset unless set explicitly.
@ -269,7 +294,8 @@ class Issue807Tests(TestCase):
"""
headers = {"HTTP_ACCEPT": RendererC.media_type}
resp = self.client.get('/', **headers)
expected = "{0}; charset={1}".format(RendererC.media_type, RendererC.charset)
expected = "{0}; charset={1}".format(RendererC.media_type,
RendererC.charset)
self.assertEqual(expected, resp['Content-Type'])
def test_content_type_set_explicitly_on_response(self):

View File

@ -6,8 +6,11 @@ import re
import pytest
from rest_framework import serializers
from rest_framework.compat import unicode_repr
from rest_framework.fields import (
CharField, IntegerField, RegexField, ValidationError
)
from rest_framework.serializers import BaseSerializer, Serializer
from .utils import MockObject
@ -17,9 +20,10 @@ from .utils import MockObject
class TestSerializer:
def setup(self):
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
class ExampleSerializer(Serializer):
char = CharField()
integer = IntegerField()
self.Serializer = ExampleSerializer
def test_valid_serializer(self):
@ -47,6 +51,7 @@ class TestSerializer:
def test_missing_attribute_during_serialization(self):
class MissingAttributes:
pass
instance = MissingAttributes()
serializer = self.Serializer(instance)
with pytest.raises(AttributeError):
@ -55,6 +60,7 @@ class TestSerializer:
def test_data_access_before_save_raises_error(self):
def create(validated_data):
return validated_data
serializer = self.Serializer(data={'char': 'abc', 'integer': 123})
serializer.create = create
assert serializer.is_valid()
@ -71,24 +77,24 @@ class TestSerializer:
class TestValidateMethod:
def test_non_field_error_validate_method(self):
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
class ExampleSerializer(Serializer):
char = CharField()
integer = IntegerField()
def validate(self, attrs):
raise serializers.ValidationError('Non field error')
raise ValidationError('Non field error')
serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
assert not serializer.is_valid()
assert serializer.errors == {'non_field_errors': ['Non field error']}
def test_field_error_validate_method(self):
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField()
integer = serializers.IntegerField()
class ExampleSerializer(Serializer):
char = CharField()
integer = IntegerField()
def validate(self, attrs):
raise serializers.ValidationError({'char': 'Field error'})
raise ValidationError({'char': 'Field error'})
serializer = ExampleSerializer(data={'char': 'abc', 'integer': 123})
assert not serializer.is_valid()
@ -97,7 +103,7 @@ class TestValidateMethod:
class TestBaseSerializer:
def setup(self):
class ExampleSerializer(serializers.BaseSerializer):
class ExampleSerializer(BaseSerializer):
def to_representation(self, obj):
return {
'id': obj['id'],
@ -167,15 +173,15 @@ class TestStarredSource:
}
def setup(self):
class NestedSerializer1(serializers.Serializer):
a = serializers.IntegerField()
b = serializers.IntegerField()
class NestedSerializer1(Serializer):
a = IntegerField()
b = IntegerField()
class NestedSerializer2(serializers.Serializer):
c = serializers.IntegerField()
d = serializers.IntegerField()
class NestedSerializer2(Serializer):
c = IntegerField()
d = IntegerField()
class TestSerializer(serializers.Serializer):
class TestSerializer(Serializer):
nested1 = NestedSerializer1(source='*')
nested2 = NestedSerializer2(source='*')
@ -205,8 +211,8 @@ class TestStarredSource:
class TestIncorrectlyConfigured:
def test_incorrect_field_name(self):
class ExampleSerializer(serializers.Serializer):
incorrect_name = serializers.IntegerField()
class ExampleSerializer(Serializer):
incorrect_name = IntegerField()
class ExampleObject:
def __init__(self):
@ -226,8 +232,8 @@ class TestIncorrectlyConfigured:
class TestUnicodeRepr:
def test_unicode_repr(self):
class ExampleSerializer(serializers.Serializer):
example = serializers.CharField()
class ExampleSerializer(Serializer):
example = CharField()
class ExampleObject:
def __init__(self):
@ -246,9 +252,10 @@ class TestNotRequiredOutput:
"""
'required=False' should allow a dictionary key to be missing in output.
"""
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(required=False)
included = serializers.CharField()
class ExampleSerializer(Serializer):
omitted = CharField(required=False)
included = CharField()
serializer = ExampleSerializer(data={'included': 'abc'})
serializer.is_valid()
@ -258,9 +265,10 @@ class TestNotRequiredOutput:
"""
'required=False' should allow an object attribute to be missing in output.
"""
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(required=False)
included = serializers.CharField()
class ExampleSerializer(Serializer):
omitted = CharField(required=False)
included = CharField()
def create(self, validated_data):
return MockObject(**validated_data)
@ -277,9 +285,10 @@ class TestNotRequiredOutput:
We need to handle this as the field will have an implicit
'required=False', but it should still have a value.
"""
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(default='abc')
included = serializers.CharField()
class ExampleSerializer(Serializer):
omitted = CharField(default='abc')
included = CharField()
serializer = ExampleSerializer({'included': 'abc'})
with pytest.raises(KeyError):
@ -292,9 +301,10 @@ class TestNotRequiredOutput:
We need to handle this as the field will have an implicit
'required=False', but it should still have a value.
"""
class ExampleSerializer(serializers.Serializer):
omitted = serializers.CharField(default='abc')
included = serializers.CharField()
class ExampleSerializer(Serializer):
omitted = CharField(default='abc')
included = CharField()
instance = MockObject(included='abc')
serializer = ExampleSerializer(instance)
@ -308,9 +318,10 @@ class TestCacheSerializerData:
Caching serializer data with pickle will drop the serializer info,
but does preserve the data itself.
"""
class ExampleSerializer(serializers.Serializer):
field1 = serializers.CharField()
field2 = serializers.CharField()
class ExampleSerializer(Serializer):
field1 = CharField()
field2 = CharField()
serializer = ExampleSerializer({'field1': 'a', 'field2': 'b'})
pickled = pickle.dumps(serializer.data)
@ -320,9 +331,10 @@ class TestCacheSerializerData:
class TestDefaultInclusions:
def setup(self):
class ExampleSerializer(serializers.Serializer):
char = serializers.CharField(read_only=True, default='abc')
integer = serializers.IntegerField()
class ExampleSerializer(Serializer):
char = CharField(read_only=True, default='abc')
integer = IntegerField()
self.Serializer = ExampleSerializer
def test_default_should_included_on_create(self):
@ -340,7 +352,8 @@ class TestDefaultInclusions:
def test_default_should_not_be_included_on_partial_update(self):
instance = MockObject(char='def', integer=123)
serializer = self.Serializer(instance, data={'integer': 456}, partial=True)
serializer = self.Serializer(instance, data={'integer': 456},
partial=True)
assert serializer.is_valid()
assert serializer.validated_data == {'integer': 456}
assert serializer.errors == {}
@ -348,8 +361,9 @@ class TestDefaultInclusions:
class TestSerializerValidationWithCompiledRegexField:
def setup(self):
class ExampleSerializer(serializers.Serializer):
name = serializers.RegexField(re.compile(r'\d'), required=True)
class ExampleSerializer(Serializer):
name = RegexField(re.compile(r'\d'), required=True)
self.Serializer = ExampleSerializer
def test_validation_success(self):

View File

@ -1,12 +1,17 @@
from django.utils.datastructures import MultiValueDict
from rest_framework import serializers
from rest_framework.fields import (
BooleanField, IntegerField, ListField, MultipleChoiceField,
ValidationError
)
from rest_framework.serializers import ListSerializer, Serializer
class BasicObject:
"""
A mock object for testing serializer save behavior.
"""
def __init__(self, **kwargs):
self._data = kwargs
for key, value in kwargs.items():
@ -28,8 +33,9 @@ class TestListSerializer:
"""
def setup(self):
class IntegerListSerializer(serializers.ListSerializer):
child = serializers.IntegerField()
class IntegerListSerializer(ListSerializer):
child = IntegerField()
self.Serializer = IntegerListSerializer
def test_validate(self):
@ -59,14 +65,14 @@ class TestListSerializerContainingNestedSerializer:
"""
def setup(self):
class TestSerializer(serializers.Serializer):
integer = serializers.IntegerField()
boolean = serializers.BooleanField()
class TestSerializer(Serializer):
integer = IntegerField()
boolean = BooleanField()
def create(self, validated_data):
return BasicObject(**validated_data)
class ObjectListSerializer(serializers.ListSerializer):
class ObjectListSerializer(ListSerializer):
child = TestSerializer()
self.Serializer = ObjectListSerializer
@ -145,9 +151,9 @@ class TestNestedListSerializer:
"""
def setup(self):
class TestSerializer(serializers.Serializer):
integers = serializers.ListSerializer(child=serializers.IntegerField())
booleans = serializers.ListSerializer(child=serializers.BooleanField())
class TestSerializer(Serializer):
integers = ListSerializer(child=IntegerField())
booleans = ListSerializer(child=BooleanField())
def create(self, validated_data):
return BasicObject(**validated_data)
@ -224,15 +230,15 @@ class TestNestedListSerializer:
class TestNestedListOfListsSerializer:
def setup(self):
class TestSerializer(serializers.Serializer):
integers = serializers.ListSerializer(
child=serializers.ListSerializer(
child=serializers.IntegerField()
class TestSerializer(Serializer):
integers = ListSerializer(
child=ListSerializer(
child=IntegerField()
)
)
booleans = serializers.ListSerializer(
child=serializers.ListSerializer(
child=serializers.BooleanField()
booleans = ListSerializer(
child=ListSerializer(
child=BooleanField()
)
)
@ -277,12 +283,13 @@ class TestNestedListOfListsSerializer:
class TestListSerializerClass:
"""Tests for a custom list_serializer_class."""
def test_list_serializer_class_validate(self):
class CustomListSerializer(serializers.ListSerializer):
def validate(self, attrs):
raise serializers.ValidationError('Non field error')
class TestSerializer(serializers.Serializer):
def test_list_serializer_class_validate(self):
class CustomListSerializer(ListSerializer):
def validate(self, attrs):
raise ValidationError('Non field error')
class TestSerializer(Serializer):
class Meta:
list_serializer_class = CustomListSerializer
@ -299,9 +306,11 @@ class TestSerializerPartialUsage:
Regression test for Github issue #2761.
"""
def test_partial_listfield(self):
class ListSerializer(serializers.Serializer):
listdata = serializers.ListField()
class ListSerializer(Serializer):
listdata = ListField()
serializer = ListSerializer(data=MultiValueDict(), partial=True)
result = serializer.to_internal_value(data={})
assert "listdata" not in result
@ -310,9 +319,11 @@ class TestSerializerPartialUsage:
assert serializer.errors == {}
def test_partial_multiplechoice(self):
class MultipleChoiceSerializer(serializers.Serializer):
multiplechoice = serializers.MultipleChoiceField(choices=[1, 2, 3])
serializer = MultipleChoiceSerializer(data=MultiValueDict(), partial=True)
class MultipleChoiceSerializer(Serializer):
multiplechoice = MultipleChoiceField(choices=[1, 2, 3])
serializer = MultipleChoiceSerializer(data=MultiValueDict(),
partial=True)
result = serializer.to_internal_value(data={})
assert "multiplechoice" not in result
assert serializer.is_valid()

View File

@ -1,15 +1,16 @@
from django.http import QueryDict
from rest_framework import serializers
from rest_framework.fields import IntegerField, MultipleChoiceField
from rest_framework.serializers import ListSerializer, Serializer
class TestNestedSerializer:
def setup(self):
class NestedSerializer(serializers.Serializer):
one = serializers.IntegerField(max_value=10)
two = serializers.IntegerField(max_value=10)
class NestedSerializer(Serializer):
one = IntegerField(max_value=10)
two = IntegerField(max_value=10)
class TestSerializer(serializers.Serializer):
class TestSerializer(Serializer):
nested = NestedSerializer()
self.Serializer = TestSerializer
@ -50,10 +51,10 @@ class TestNestedSerializer:
class TestNotRequiredNestedSerializer:
def setup(self):
class NestedSerializer(serializers.Serializer):
one = serializers.IntegerField(max_value=10)
class NestedSerializer(Serializer):
one = IntegerField(max_value=10)
class TestSerializer(serializers.Serializer):
class TestSerializer(Serializer):
nested = NestedSerializer(required=False)
self.Serializer = TestSerializer
@ -79,10 +80,10 @@ class TestNotRequiredNestedSerializer:
class TestNestedSerializerWithMany:
def setup(self):
class NestedSerializer(serializers.Serializer):
example = serializers.IntegerField(max_value=10)
class NestedSerializer(Serializer):
example = IntegerField(max_value=10)
class TestSerializer(serializers.Serializer):
class TestSerializer(Serializer):
allow_null = NestedSerializer(many=True, allow_null=True)
not_allow_null = NestedSerializer(many=True)
allow_empty = NestedSerializer(many=True, allow_empty=True)
@ -119,7 +120,8 @@ class TestNestedSerializerWithMany:
assert not serializer.is_valid()
expected_errors = {'not_allow_null': [serializer.error_messages['null']]}
expected_errors = {
'not_allow_null': [serializer.error_messages['null']]}
assert serializer.errors == expected_errors
def test_run_the_field_validation_even_if_the_field_is_null(self):
@ -171,16 +173,17 @@ class TestNestedSerializerWithMany:
assert not serializer.is_valid()
expected_errors = {'not_allow_empty': {'non_field_errors': [serializers.ListSerializer.default_error_messages['empty']]}}
expected_errors = {'not_allow_empty': {'non_field_errors': [
ListSerializer.default_error_messages['empty']]}}
assert serializer.errors == expected_errors
class TestNestedSerializerWithList:
def setup(self):
class NestedSerializer(serializers.Serializer):
example = serializers.MultipleChoiceField(choices=[1, 2, 3])
class NestedSerializer(Serializer):
example = MultipleChoiceField(choices=[1, 2, 3])
class TestSerializer(serializers.Serializer):
class TestSerializer(Serializer):
nested = NestedSerializer()
self.Serializer = TestSerializer

View File

@ -43,7 +43,8 @@ urlpatterns = [
url(r'^resource/customname$', CustomNameResourceInstance.as_view()),
url(r'^resource/(?P<key>[0-9]+)$', ResourceInstance.as_view()),
url(r'^resource/(?P<key>[0-9]+)/$', NestedResourceRoot.as_view()),
url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$', NestedResourceInstance.as_view()),
url(r'^resource/(?P<key>[0-9]+)/(?P<other>[A-Za-z]+)$',
NestedResourceInstance.as_view()),
]
@ -52,6 +53,7 @@ class BreadcrumbTests(TestCase):
"""
Tests the breadcrumb functionality used by the HTML renderer.
"""
def test_root_breadcrumbs(self):
url = '/'
self.assertEqual(
@ -130,6 +132,7 @@ class ResolveModelTests(TestCase):
provided argument is a Django model class itself, or a properly
formatted string representation of one.
"""
def test_resolve_django_model(self):
resolved_model = _resolve_model(BasicModel)
self.assertEqual(resolved_model, BasicModel)

View File

@ -1,4 +1,5 @@
from django.core.exceptions import ObjectDoesNotExist
from rest_framework.compat import NoReverseMatch