Merge remote-tracking branch 'upstream/master' into route_fix

This commit is contained in:
Nik 2016-08-11 21:27:43 +03:00
commit 070094e5ec
30 changed files with 408 additions and 116 deletions

View File

@ -18,8 +18,7 @@ REST framework commercially we strongly encourage you to invest in its
continued development by **[signing up for a paid plan][funding]**. continued development by **[signing up for a paid plan][funding]**.
The initial aim is to provide a single full-time position on REST framework. The initial aim is to provide a single full-time position on REST framework.
Right now we're over 58% of the way towards achieving that. *Every single sign-up makes a significant impact towards making that possible.*
*Every single sign-up makes a significant impact.*
<p align="center"> <p align="center">
<a href="http://jobs.rover.com/"><img src="https://raw.githubusercontent.com/tomchristie/django-rest-framework/master/docs/img/premium/rover-readme.png"/></a> <a href="http://jobs.rover.com/"><img src="https://raw.githubusercontent.com/tomchristie/django-rest-framework/master/docs/img/premium/rover-readme.png"/></a>

View File

@ -132,7 +132,7 @@ This permission is suitable if you want to your API to allow read permissions to
## DjangoModelPermissions ## DjangoModelPermissions
This permission class ties into Django's standard `django.contrib.auth` [model permissions][contribauth]. This permission must only be applied to views that has a `.queryset` property set. Authorization will only be granted if the user *is authenticated* and has the *relevant model permissions* assigned. This permission class ties into Django's standard `django.contrib.auth` [model permissions][contribauth]. This permission must only be applied to views that have a `.queryset` property set. Authorization will only be granted if the user *is authenticated* and has the *relevant model permissions* assigned.
* `POST` requests require the user to have the `add` permission on the model. * `POST` requests require the user to have the `add` permission on the model.
* `PUT` and `PATCH` requests require the user to have the `change` permission on the model. * `PUT` and `PATCH` requests require the user to have the `change` permission on the model.

View File

@ -71,8 +71,8 @@ You can also set the versioning scheme on an individual view. Typically you won'
The following settings keys are also used to control versioning: The following settings keys are also used to control versioning:
* `DEFAULT_VERSION`. The value that should be used for `request.version` when no versioning information is present. Defaults to `None`. * `DEFAULT_VERSION`. The value that should be used for `request.version` when no versioning information is present. Defaults to `None`.
* `ALLOWED_VERSIONS`. If set, this value will restrict the set of versions that may be returned by the versioning scheme, and will raise an error if the provided version if not in this set. Note that the value used for the `DEFAULT_VERSION` setting is always considered to be part of the `ALLOWED_VERSIONS` set. Defaults to `None`. * `ALLOWED_VERSIONS`. If set, this value will restrict the set of versions that may be returned by the versioning scheme, and will raise an error if the provided version is not in this set. Note that the value used for the `DEFAULT_VERSION` setting is always considered to be part of the `ALLOWED_VERSIONS` set (unless it is `None`). Defaults to `None`.
* `VERSION_PARAM`. The string that should used for any versioning parameters, such as in the media type or URL query parameters. Defaults to `'version'`. * `VERSION_PARAM`. The string that should be used for any versioning parameters, such as in the media type or URL query parameters. Defaults to `'version'`.
You can also set your versioning class plus those three values on a per-view or a per-viewset basis by defining your own versioning scheme and using the `default_version`, `allowed_versions` and `version_param` class variables. For example, if you want to use `URLPathVersioning`: You can also set your versioning class plus those three values on a per-view or a per-viewset basis by defining your own versioning scheme and using the `default_version`, `allowed_versions` and `version_param` class variables. For example, if you want to use `URLPathVersioning`:

View File

@ -68,8 +68,7 @@ REST framework commercially we strongly encourage you to invest in its
continued development by **[signing up for a paid plan][funding]**. continued development by **[signing up for a paid plan][funding]**.
The initial aim is to provide a single full-time position on REST framework. The initial aim is to provide a single full-time position on REST framework.
Right now we're over 58% of the way towards achieving that. *Every single sign-up makes a significant impact towards making that possible.*
*Every single sign-up makes a significant impact.*
<ul class="premium-promo promo"> <ul class="premium-promo promo">
<li><a href="http://jobs.rover.com/" style="background-image: url(https://fund-rest-framework.s3.amazonaws.com/rover_130x130.png)">Rover.com</a></li> <li><a href="http://jobs.rover.com/" style="background-image: url(https://fund-rest-framework.s3.amazonaws.com/rover_130x130.png)">Rover.com</a></li>
@ -87,7 +86,7 @@ Right now we're over 58% of the way towards achieving that.
REST framework requires the following: REST framework requires the following:
* Python (2.7, 3.2, 3.3, 3.4, 3.5) * Python (2.7, 3.2, 3.3, 3.4, 3.5)
* Django (1.7+, 1.8, 1.9) * Django (1.8, 1.9, 1.10)
The following packages are optional: The following packages are optional:

View File

@ -672,6 +672,7 @@ class NullBooleanField(Field):
class CharField(Field): class CharField(Field):
default_error_messages = { default_error_messages = {
'invalid': _('Not a valid string.'),
'blank': _('This field may not be blank.'), 'blank': _('This field may not be blank.'),
'max_length': _('Ensure this field has no more than {max_length} characters.'), 'max_length': _('Ensure this field has no more than {max_length} characters.'),
'min_length': _('Ensure this field has at least {min_length} characters.') 'min_length': _('Ensure this field has at least {min_length} characters.')
@ -702,6 +703,11 @@ class CharField(Field):
return super(CharField, self).run_validation(data) return super(CharField, self).run_validation(data)
def to_internal_value(self, data): def to_internal_value(self, data):
# We're lenient with allowing basic numerics to be coerced into strings,
# but other types should fail. Eg. unclear if booleans should represent as `true` or `True`,
# and composites such as lists are likely user error.
if isinstance(data, bool) or not isinstance(data, six.string_types + six.integer_types + (float,)):
self.fail('invalid')
value = six.text_type(data) value = six.text_type(data)
return value.strip() if self.trim_whitespace else value return value.strip() if self.trim_whitespace else value
@ -1016,7 +1022,8 @@ class DecimalField(Field):
return value return value
context = decimal.getcontext().copy() context = decimal.getcontext().copy()
context.prec = self.max_digits if self.max_digits is not None:
context.prec = self.max_digits
return value.quantize( return value.quantize(
decimal.Decimal('.1') ** self.decimal_places, decimal.Decimal('.1') ** self.decimal_places,
context=context context=context

View File

@ -312,6 +312,9 @@ class LimitOffsetPagination(BasePagination):
self.request = request self.request = request
if self.count > self.limit and self.template is not None: if self.count > self.limit and self.template is not None:
self.display_page_controls = True self.display_page_controls = True
if self.count == 0 or self.offset > self.count:
return []
return list(queryset[self.offset:self.offset + self.limit]) return list(queryset[self.offset:self.offset + self.limit])
def get_paginated_response(self, data): def get_paginated_response(self, data):

View File

@ -156,29 +156,35 @@ class RelatedField(Field):
# Standard case, return the object instance. # Standard case, return the object instance.
return get_attribute(instance, self.source_attrs) return get_attribute(instance, self.source_attrs)
@property def get_choices(self, cutoff=None):
def choices(self):
queryset = self.get_queryset() queryset = self.get_queryset()
if queryset is None: if queryset is None:
# Ensure that field.choices returns something sensible # Ensure that field.choices returns something sensible
# even when accessed with a read-only field. # even when accessed with a read-only field.
return {} return {}
if cutoff is not None:
queryset = queryset[:cutoff]
return OrderedDict([ return OrderedDict([
( (
six.text_type(self.to_representation(item)), self.to_representation(item),
self.display_value(item) self.display_value(item)
) )
for item in queryset for item in queryset
]) ])
@property
def choices(self):
return self.get_choices()
@property @property
def grouped_choices(self): def grouped_choices(self):
return self.choices return self.choices
def iter_options(self): def iter_options(self):
return iter_options( return iter_options(
self.grouped_choices, self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff, cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text cutoff_text=self.html_cutoff_text
) )
@ -487,9 +493,12 @@ class ManyRelatedField(Field):
for value in iterable for value in iterable
] ]
def get_choices(self, cutoff=None):
return self.child_relation.get_choices(cutoff)
@property @property
def choices(self): def choices(self):
return self.child_relation.choices return self.get_choices()
@property @property
def grouped_choices(self): def grouped_choices(self):
@ -497,7 +506,7 @@ class ManyRelatedField(Field):
def iter_options(self): def iter_options(self):
return iter_options( return iter_options(
self.grouped_choices, self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff, cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text cutoff_text=self.html_cutoff_text
) )

View File

@ -4,6 +4,7 @@ from django.conf import settings
from django.contrib.admindocs.views import simplify_regex from django.contrib.admindocs.views import simplify_regex
from django.core.urlresolvers import RegexURLPattern, RegexURLResolver from django.core.urlresolvers import RegexURLPattern, RegexURLResolver
from django.utils import six from django.utils import six
from django.utils.encoding import force_text
from rest_framework import exceptions, serializers from rest_framework import exceptions, serializers
from rest_framework.compat import coreapi, uritemplate, urlparse from rest_framework.compat import coreapi, uritemplate, urlparse
@ -30,24 +31,6 @@ def is_api_view(callback):
return (cls is not None) and issubclass(cls, APIView) return (cls is not None) and issubclass(cls, APIView)
def insert_into(target, keys, item):
"""
Insert `item` into the nested dictionary `target`.
For example:
target = {}
insert_into(target, ('users', 'list'), Link(...))
insert_into(target, ('users', 'detail'), Link(...))
assert target == {'users': {'list': Link(...), 'detail': Link(...)}}
"""
for key in keys[:1]:
if key not in target:
target[key] = {}
target = target[key]
target[keys[-1]] = item
class SchemaGenerator(object): class SchemaGenerator(object):
default_mapping = { default_mapping = {
'get': 'read', 'get': 'read',
@ -65,46 +48,57 @@ class SchemaGenerator(object):
urls = import_module(urlconf) urls = import_module(urlconf)
else: else:
urls = urlconf urls = urlconf
patterns = urls.urlpatterns self.patterns = urls.urlpatterns
elif patterns is None and urlconf is None: elif patterns is None and urlconf is None:
urls = import_module(settings.ROOT_URLCONF) urls = import_module(settings.ROOT_URLCONF)
patterns = urls.urlpatterns self.patterns = urls.urlpatterns
else:
self.patterns = patterns
if url and not url.endswith('/'): if url and not url.endswith('/'):
url += '/' url += '/'
self.title = title self.title = title
self.url = url self.url = url
self.endpoints = self.get_api_endpoints(patterns) self.endpoints = None
def get_schema(self, request=None): def get_schema(self, request=None):
if request is None: if self.endpoints is None:
endpoints = self.endpoints self.endpoints = self.get_api_endpoints(self.patterns)
else:
# Filter the list of endpoints to only include those that links = []
# the user has permission on. for path, method, category, action, callback in self.endpoints:
endpoints = [] view = self.get_view(callback)
for key, link, callback in self.endpoints: view.args = ()
method = link.action.upper() view.kwargs = {}
view = self.get_view(callback) view.format_kwarg = None
if request is not None:
view.request = clone_request(request, method) view.request = clone_request(request, method)
view.format_kwarg = None
try: try:
view.check_permissions(view.request) view.check_permissions(view.request)
except exceptions.APIException: except exceptions.APIException:
pass continue
else: else:
endpoints.append((key, link, callback)) view.request = None
if not endpoints: link = self.get_link(path, method, callback, view)
links.append((category, action, link))
if not links:
return None return None
# Generate the schema content structure, from the endpoints. # Generate the schema content structure, eg:
# ('users', 'list'), Link -> {'users': {'list': Link()}} # {'users': {'list': Link()}}
content = {} content = {}
for key, link, callback in endpoints: for category, action, link in links:
insert_into(content, key, link) if category is None:
content[action] = link
elif category in content:
content[category][action] = link
else:
content[category] = {action: link}
# Return the schema document. # Return the schema document.
return coreapi.Document(title=self.title, content=content, url=self.url) return coreapi.Document(title=self.title, content=content, url=self.url)
@ -122,9 +116,8 @@ class SchemaGenerator(object):
callback = pattern.callback callback = pattern.callback
if self.should_include_endpoint(path, callback): if self.should_include_endpoint(path, callback):
for method in self.get_allowed_methods(callback): for method in self.get_allowed_methods(callback):
key = self.get_key(path, method, callback) action = self.get_action(path, method, callback)
link = self.get_link(path, method, callback) endpoint = (path, method, action, callback)
endpoint = (key, link, callback)
api_endpoints.append(endpoint) api_endpoints.append(endpoint)
elif isinstance(pattern, RegexURLResolver): elif isinstance(pattern, RegexURLResolver):
@ -134,7 +127,21 @@ class SchemaGenerator(object):
) )
api_endpoints.extend(nested_endpoints) api_endpoints.extend(nested_endpoints)
return api_endpoints return self.add_categories(api_endpoints)
def add_categories(self, api_endpoints):
"""
(path, method, action, callback) -> (path, method, category, action, callback)
"""
# Determine the top level categories for the schema content,
# based on the URLs of the endpoints. Eg `set(['users', 'organisations'])`
paths = [endpoint[0] for endpoint in api_endpoints]
categories = self.get_categories(paths)
return [
(path, method, self.get_category(categories, path), action, callback)
for (path, method, action, callback) in api_endpoints
]
def get_view(self, callback): def get_view(self, callback):
""" """
@ -181,32 +188,45 @@ class SchemaGenerator(object):
view.allowed_methods if method not in ('OPTIONS', 'HEAD') view.allowed_methods if method not in ('OPTIONS', 'HEAD')
] ]
def get_key(self, path, method, callback): def get_action(self, path, method, callback):
""" """
Return a tuple of strings, indicating the identity to use for a Return a description action string for the endpoint, eg. 'list'.
given endpoint. eg. ('users', 'list').
""" """
category = None
for item in path.strip('/').split('/'):
if '{' in item:
break
category = item
actions = getattr(callback, 'actions', self.default_mapping) actions = getattr(callback, 'actions', self.default_mapping)
action = actions[method.lower()] return actions[method.lower()]
if category: def get_categories(self, paths):
return (category, action) categories = set()
return (action,) split_paths = set([
tuple(path.split("{")[0].strip('/').split('/'))
for path in paths
])
while split_paths:
for split_path in list(split_paths):
if len(split_path) == 0:
split_paths.remove(split_path)
elif len(split_path) == 1:
categories.add(split_path[0])
split_paths.remove(split_path)
elif split_path[0] in categories:
split_paths.remove(split_path)
return categories
def get_category(self, categories, path):
path_components = path.split("{")[0].strip('/').split('/')
for path_component in path_components:
if path_component in categories:
return path_component
return None
# Methods for generating each individual `Link` instance... # Methods for generating each individual `Link` instance...
def get_link(self, path, method, callback): def get_link(self, path, method, callback, view):
""" """
Return a `coreapi.Link` instance for the given endpoint. Return a `coreapi.Link` instance for the given endpoint.
""" """
view = self.get_view(callback)
fields = self.get_path_fields(path, method, callback, view) fields = self.get_path_fields(path, method, callback, view)
fields += self.get_serializer_fields(path, method, callback, view) fields += self.get_serializer_fields(path, method, callback, view)
fields += self.get_pagination_fields(path, method, callback, view) fields += self.get_pagination_fields(path, method, callback, view)
@ -269,25 +289,29 @@ class SchemaGenerator(object):
if method not in ('PUT', 'PATCH', 'POST'): if method not in ('PUT', 'PATCH', 'POST'):
return [] return []
if not hasattr(view, 'get_serializer_class'): if not hasattr(view, 'get_serializer'):
return [] return []
fields = [] serializer = view.get_serializer()
serializer_class = view.get_serializer_class()
serializer = serializer_class()
if isinstance(serializer, serializers.ListSerializer): if isinstance(serializer, serializers.ListSerializer):
return coreapi.Field(name='data', location='body', required=True) return [coreapi.Field(name='data', location='body', required=True)]
if not isinstance(serializer, serializers.Serializer): if not isinstance(serializer, serializers.Serializer):
return [] return []
fields = []
for field in serializer.fields.values(): for field in serializer.fields.values():
if field.read_only: if field.read_only:
continue continue
required = field.required and method != 'PATCH' required = field.required and method != 'PATCH'
field = coreapi.Field(name=field.source, location='form', required=required) description = force_text(field.help_text) if field.help_text else ''
field = coreapi.Field(
name=field.source,
location='form',
required=required,
description=description
)
fields.append(field) fields.append(field)
return fields return fields

View File

@ -1,3 +1,5 @@
{% load rest_framework %}
<div class="form-group"> <div class="form-group">
{% if field.label %} {% if field.label %}
<label class="col-sm-2 control-label {% if style.hide_label %}sr-only{% endif %}"> <label class="col-sm-2 control-label {% if style.hide_label %}sr-only{% endif %}">
@ -9,7 +11,7 @@
{% if style.inline %} {% if style.inline %}
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<label class="checkbox-inline"> <label class="checkbox-inline">
<input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key in field.value %}checked{% endif %}> <input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key|as_string in field.value|as_list_of_strings %}checked{% endif %}>
{{ text }} {{ text }}
</label> </label>
{% endfor %} {% endfor %}
@ -17,7 +19,7 @@
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<div class="checkbox"> <div class="checkbox">
<label> <label>
<input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key in field.value %}checked{% endif %}> <input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key|as_string in field.value|as_list_of_strings %}checked{% endif %}>
{{ text }} {{ text }}
</label> </label>
</div> </div>

View File

@ -1,4 +1,6 @@
{% load i18n %} {% load i18n %}
{% load rest_framework %}
{% trans "None" as none_choice %} {% trans "None" as none_choice %}
<div class="form-group"> <div class="form-group">
@ -19,7 +21,7 @@
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<label class="radio-inline"> <label class="radio-inline">
<input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key == field.value %}checked{% endif %} /> <input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key|as_string == field.value|as_string %}checked{% endif %} />
{{ text }} {{ text }}
</label> </label>
{% endfor %} {% endfor %}
@ -35,7 +37,7 @@
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<div class="radio"> <div class="radio">
<label> <label>
<input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key == field.value %}checked{% endif %} /> <input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key|as_string == field.value|as_string %}checked{% endif %} />
{{ text }} {{ text }}
</label> </label>
</div> </div>

View File

@ -1,3 +1,5 @@
{% load rest_framework %}
<div class="form-group"> <div class="form-group">
{% if field.label %} {% if field.label %}
<label class="col-sm-2 control-label {% if style.hide_label %}sr-only{% endif %}"> <label class="col-sm-2 control-label {% if style.hide_label %}sr-only{% endif %}">
@ -16,7 +18,7 @@
{% elif select.end_option_group %} {% elif select.end_option_group %}
</optgroup> </optgroup>
{% else %} {% else %}
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option> <option value="{{ select.value }}" {% if select.value|as_string == field.value|as_string %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option>
{% endif %} {% endif %}
{% endfor %} {% endfor %}
</select> </select>

View File

@ -1,4 +1,6 @@
{% load i18n %} {% load i18n %}
{% load rest_framework %}
{% trans "No items to select." as no_items %} {% trans "No items to select." as no_items %}
<div class="form-group"> <div class="form-group">
@ -16,7 +18,7 @@
{% elif select.end_option_group %} {% elif select.end_option_group %}
</optgroup> </optgroup>
{% else %} {% else %}
<option value="{{ select.value }}" {% if select.value in field.value %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option> <option value="{{ select.value }}" {% if select.value|as_string in field.value|as_list_of_strings %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option>
{% endif %} {% endif %}
{% empty %} {% empty %}
<option>{{ no_items }}</option> <option>{{ no_items }}</option>

View File

@ -1,3 +1,5 @@
{% load rest_framework %}
<div class="form-group {% if field.errors %}has-error{% endif %}"> <div class="form-group {% if field.errors %}has-error{% endif %}">
{% if field.label %} {% if field.label %}
<label class="sr-only">{{ field.label }}</label> <label class="sr-only">{{ field.label }}</label>
@ -6,7 +8,7 @@
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<div class="checkbox"> <div class="checkbox">
<label> <label>
<input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key in field.value %}checked{% endif %}> <input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key|as_string in field.value|as_list_of_strings %}checked{% endif %}>
{{ text }} {{ text }}
</label> </label>
</div> </div>

View File

@ -1,4 +1,5 @@
{% load i18n %} {% load i18n %}
{% load rest_framework %}
{% trans "None" as none_choice %} {% trans "None" as none_choice %}
<div class="form-group {% if field.errors %}has-error{% endif %}"> <div class="form-group {% if field.errors %}has-error{% endif %}">
@ -20,7 +21,7 @@
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<div class="radio"> <div class="radio">
<label> <label>
<input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key == field.value %}checked{% endif %}> <input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key|as_string == field.value|as_string %}checked{% endif %}>
{{ text }} {{ text }}
</label> </label>
</div> </div>

View File

@ -1,3 +1,5 @@
{% load rest_framework %}
<div class="form-group {% if field.errors %}has-error{% endif %}"> <div class="form-group {% if field.errors %}has-error{% endif %}">
{% if field.label %} {% if field.label %}
<label class="sr-only"> <label class="sr-only">
@ -15,7 +17,7 @@
{% elif select.end_option_group %} {% elif select.end_option_group %}
</optgroup> </optgroup>
{% else %} {% else %}
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option> <option value="{{ select.value }}" {% if select.value|as_string == field.value|as_string %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option>
{% endif %} {% endif %}
{% endfor %} {% endfor %}
</select> </select>

View File

@ -1,4 +1,5 @@
{% load i18n %} {% load i18n %}
{% load rest_framework %}
{% trans "No items to select." as no_items %} {% trans "No items to select." as no_items %}
<div class="form-group {% if field.errors %}has-error{% endif %}"> <div class="form-group {% if field.errors %}has-error{% endif %}">
@ -15,7 +16,7 @@
{% elif select.end_option_group %} {% elif select.end_option_group %}
</optgroup> </optgroup>
{% else %} {% else %}
<option value="{{ select.value }}" {% if select.value in field.value %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option> <option value="{{ select.value }}" {% if select.value|as_string in field.value|as_list_of_strings %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option>
{% endif %} {% endif %}
{% empty %} {% empty %}
<option>{{ no_items }}</option> <option>{{ no_items }}</option>

View File

@ -1,3 +1,5 @@
{% load rest_framework %}
<div class="form-group {% if field.errors %}has-error{% endif %}"> <div class="form-group {% if field.errors %}has-error{% endif %}">
{% if field.label %} {% if field.label %}
<label {% if style.hide_label %}class="sr-only"{% endif %}>{{ field.label }}</label> <label {% if style.hide_label %}class="sr-only"{% endif %}>{{ field.label }}</label>
@ -7,7 +9,7 @@
<div> <div>
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<label class="checkbox-inline"> <label class="checkbox-inline">
<input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key in field.value %}checked{% endif %}> <input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key|as_string in field.value|as_list_of_stringsg %}checked{% endif %}>
{{ text }} {{ text }}
</label> </label>
{% endfor %} {% endfor %}
@ -16,7 +18,7 @@
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<div class="checkbox"> <div class="checkbox">
<label> <label>
<input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key in field.value %}checked{% endif %}> <input type="checkbox" name="{{ field.name }}" value="{{ key }}" {% if key|as_string in field.value|as_list_of_stringsg %}checked{% endif %}>
{{ text }} {{ text }}
</label> </label>
</div> </div>

View File

@ -1,4 +1,5 @@
{% load i18n %} {% load i18n %}
{% load rest_framework %}
{% trans "None" as none_choice %} {% trans "None" as none_choice %}
<div class="form-group {% if field.errors %}has-error{% endif %}"> <div class="form-group {% if field.errors %}has-error{% endif %}">
@ -19,7 +20,7 @@
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<label class="radio-inline"> <label class="radio-inline">
<input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key == field.value %}checked{% endif %}> <input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key|as_string == field.value|as_string %}checked{% endif %}>
{{ text }} {{ text }}
</label> </label>
{% endfor %} {% endfor %}
@ -37,7 +38,7 @@
{% for key, text in field.choices.items %} {% for key, text in field.choices.items %}
<div class="radio"> <div class="radio">
<label> <label>
<input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key == field.value %}checked{% endif %}> <input type="radio" name="{{ field.name }}" value="{{ key }}" {% if key|as_string == field.value|as_string %}checked{% endif %}>
{{ text }} {{ text }}
</label> </label>
</div> </div>

View File

@ -1,3 +1,5 @@
{% load rest_framework %}
<div class="form-group {% if field.errors %}has-error{% endif %}"> <div class="form-group {% if field.errors %}has-error{% endif %}">
{% if field.label %} {% if field.label %}
<label {% if style.hide_label %}class="sr-only"{% endif %}> <label {% if style.hide_label %}class="sr-only"{% endif %}>
@ -15,7 +17,7 @@
{% elif select.end_option_group %} {% elif select.end_option_group %}
</optgroup> </optgroup>
{% else %} {% else %}
<option value="{{ select.value }}" {% if select.value == field.value %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option> <option value="{{ select.value }}" {% if select.value|as_string == field.value|as_string %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option>
{% endif %} {% endif %}
{% endfor %} {% endfor %}
</select> </select>

View File

@ -1,4 +1,5 @@
{% load i18n %} {% load i18n %}
{% load rest_framework %}
{% trans "No items to select." as no_items %} {% trans "No items to select." as no_items %}
<div class="form-group {% if field.errors %}has-error{% endif %}"> <div class="form-group {% if field.errors %}has-error{% endif %}">
@ -15,7 +16,7 @@
{% elif select.end_option_group %} {% elif select.end_option_group %}
</optgroup> </optgroup>
{% else %} {% else %}
<option value="{{ select.value }}" {% if select.value in field.value %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option> <option value="{{ select.value }}" {% if select.value|as_string in field.value|as_list_of_strings %}selected{% endif %} {% if select.disabled %}disabled{% endif %}>{{ select.display_text }}</option>
{% endif %} {% endif %}
{% empty %} {% empty %}
<option>{{ no_items }}</option> <option>{{ no_items }}</option>

View File

@ -89,6 +89,21 @@ def add_query_param(request, key, val):
return escape(replace_query_param(uri, key, val)) return escape(replace_query_param(uri, key, val))
@register.filter
def as_string(value):
if value is None:
return ''
return '%s' % value
@register.filter
def as_list_of_strings(value):
return [
'' if (item is None) else ('%s' % item)
for item in value
]
@register.filter @register.filter
def add_class(value, css_class): def add_class(value, css_class):
""" """

View File

@ -78,7 +78,7 @@ class BoundField(object):
)) ))
def as_form_field(self): def as_form_field(self):
value = '' if (self.value is None or self.value is False) else force_text(self.value) value = '' if (self.value is None or self.value is False) else self.value
return self.__class__(self._field, value, self.errors, self._prefix) return self.__class__(self._field, value, self.errors, self._prefix)

View File

@ -30,7 +30,8 @@ class BaseVersioning(object):
def is_allowed_version(self, version): def is_allowed_version(self, version):
if not self.allowed_versions: if not self.allowed_versions:
return True return True
return (version == self.default_version) or (version in self.allowed_versions) return ((version is not None and version == self.default_version) or
(version in self.allowed_versions))
class AcceptHeaderVersioning(BaseVersioning): class AcceptHeaderVersioning(BaseVersioning):

View File

@ -3,13 +3,17 @@ Provides an APIView class that is the base of all views in REST framework.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import sys
from django.conf import settings
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied
from django.db import models from django.db import models
from django.http import Http404 from django.http import Http404
from django.http.response import HttpResponseBase from django.http.response import HttpResponse, HttpResponseBase
from django.utils import six from django.utils import six
from django.utils.encoding import smart_text from django.utils.encoding import smart_text
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.views import debug
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View from django.views.generic import View
@ -91,7 +95,11 @@ def exception_handler(exc, context):
set_rollback() set_rollback()
return Response(data, status=status.HTTP_403_FORBIDDEN) return Response(data, status=status.HTTP_403_FORBIDDEN)
# Note: Unhandled exceptions will raise a 500 error. # throw django's error page if debug is True
if settings.DEBUG:
exception_reporter = debug.ExceptionReporter(context.get('request'), *sys.exc_info())
return HttpResponse(exception_reporter.get_traceback_html(), status=500)
return None return None

View File

@ -55,6 +55,30 @@ class TestSimpleBoundField:
assert serializer['bool_field'].as_form_field().value == '' assert serializer['bool_field'].as_form_field().value == ''
assert serializer['null_field'].as_form_field().value == '' assert serializer['null_field'].as_form_field().value == ''
def test_rendering_boolean_field(self):
from rest_framework.renderers import HTMLFormRenderer
class ExampleSerializer(serializers.Serializer):
bool_field = serializers.BooleanField(
style={'base_template': 'checkbox.html', 'template_pack': 'rest_framework/vertical'})
serializer = ExampleSerializer(data={'bool_field': True})
assert serializer.is_valid()
renderer = HTMLFormRenderer()
rendered = renderer.render_field(serializer['bool_field'], {})
expected_packed = (
'<divclass="form-group">'
'<divclass="checkbox">'
'<label>'
'<inputtype="checkbox"name="bool_field"value="true"checked>'
'Boolfield'
'</label>'
'</div>'
'</div>'
)
rendered_packed = ''.join(rendered.split())
assert rendered_packed == expected_packed
class TestNestedBoundField: class TestNestedBoundField:
def test_nested_empty_bound_field(self): def test_nested_empty_bound_field(self):

View File

@ -535,6 +535,8 @@ class TestCharField(FieldValues):
'abc': 'abc' 'abc': 'abc'
} }
invalid_inputs = { invalid_inputs = {
(): ['Not a valid string.'],
True: ['Not a valid string.'],
'': ['This field may not be blank.'] '': ['This field may not be blank.']
} }
outputs = { outputs = {
@ -876,6 +878,18 @@ class TestMinMaxDecimalField(FieldValues):
) )
class TestNoMaxDigitsDecimalField(FieldValues):
field = serializers.DecimalField(
max_value=100, min_value=0,
decimal_places=2, max_digits=None
)
valid_inputs = {
'10': Decimal('10.00')
}
invalid_inputs = {}
outputs = {}
class TestNoStringCoercionDecimalField(FieldValues): class TestNoStringCoercionDecimalField(FieldValues):
""" """
Output values for `DecimalField` with `coerce_to_string=False`. Output values for `DecimalField` with `coerce_to_string=False`.

View File

@ -614,7 +614,7 @@ class TestRelationalFieldDisplayValue(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
expected = OrderedDict([('1', 'Red Color'), ('2', 'Yellow Color'), ('3', 'Green Color')]) expected = OrderedDict([(1, 'Red Color'), (2, 'Yellow Color'), (3, 'Green Color')])
self.assertEqual(serializer.fields['color'].choices, expected) self.assertEqual(serializer.fields['color'].choices, expected)
def test_custom_display_value(self): def test_custom_display_value(self):
@ -630,7 +630,7 @@ class TestRelationalFieldDisplayValue(TestCase):
fields = '__all__' fields = '__all__'
serializer = TestSerializer() serializer = TestSerializer()
expected = OrderedDict([('1', 'My Red Color'), ('2', 'My Yellow Color'), ('3', 'My Green Color')]) expected = OrderedDict([(1, 'My Red Color'), (2, 'My Yellow Color'), (3, 'My Green Color')])
self.assertEqual(serializer.fields['color'].choices, expected) self.assertEqual(serializer.fields['color'].choices, expected)

View File

@ -481,3 +481,90 @@ class TestHTMLFormRenderer(TestCase):
result = renderer.render(self.serializer.data, None, {}) result = renderer.render(self.serializer.data, None, {})
self.assertIsInstance(result, SafeText) self.assertIsInstance(result, SafeText)
class TestChoiceFieldHTMLFormRenderer(TestCase):
"""
Test rendering ChoiceField with HTMLFormRenderer.
"""
def setUp(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
test_field = serializers.ChoiceField(choices=choices,
initial=2)
self.TestSerializer = TestSerializer
self.renderer = HTMLFormRenderer()
def test_render_initial_option(self):
serializer = self.TestSerializer()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
self.assertInHTML('<option value="2" selected>Option2</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="12">Option12</option>', result)
def test_render_selected_option(self):
serializer = self.TestSerializer(data={'test_field': '12'})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
class TestMultipleChoiceFieldHTMLFormRenderer(TestCase):
"""
Test rendering MultipleChoiceField with HTMLFormRenderer.
"""
def setUp(self):
self.renderer = HTMLFormRenderer()
def test_render_selected_option_with_string_option_ids(self):
choices = (('1', 'Option1'), ('2', 'Option2'), ('12', 'Option12'),
('}', 'OptionBrace'))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)
self.assertInHTML('<option value="}">OptionBrace</option>', result)
def test_render_selected_option_with_integer_option_ids(self):
choices = ((1, 'Option1'), (2, 'Option2'), (12, 'Option12'))
class TestSerializer(serializers.Serializer):
test_field = serializers.MultipleChoiceField(choices=choices)
serializer = TestSerializer(data={'test_field': ['12']})
serializer.is_valid()
result = self.renderer.render(serializer.data)
self.assertIsInstance(result, SafeText)
self.assertInHTML('<option value="12" selected>Option12</option>',
result)
self.assertInHTML('<option value="1">Option1</option>', result)
self.assertInHTML('<option value="2">Option2</option>', result)

View File

@ -5,7 +5,7 @@ from django.test import TestCase, override_settings
from rest_framework import filters, pagination, permissions, serializers from rest_framework import filters, pagination, permissions, serializers
from rest_framework.compat import coreapi from rest_framework.compat import coreapi
from rest_framework.decorators import detail_route from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from rest_framework.schemas import SchemaGenerator from rest_framework.schemas import SchemaGenerator
@ -24,7 +24,7 @@ class ExamplePagination(pagination.PageNumberPagination):
class ExampleSerializer(serializers.Serializer): class ExampleSerializer(serializers.Serializer):
a = serializers.CharField(required=True) a = serializers.CharField(required=True, help_text='A field description')
b = serializers.CharField(required=False) b = serializers.CharField(required=False)
@ -53,6 +53,14 @@ class ExampleViewSet(ModelViewSet):
def forbidden_action(self, request, pk): def forbidden_action(self, request, pk):
return super(ExampleSerializer, self).update(self, request) return super(ExampleSerializer, self).update(self, request)
@list_route()
def custom_list_action(self, request):
return super(ExampleViewSet, self).list(self, request)
def get_serializer(self, *args, **kwargs):
assert self.request
return super(ExampleViewSet, self).get_serializer(*args, **kwargs)
class ExampleView(APIView): class ExampleView(APIView):
permission_classes = [permissions.IsAuthenticatedOrReadOnly] permission_classes = [permissions.IsAuthenticatedOrReadOnly]
@ -94,6 +102,10 @@ class TestRouterGeneratedSchema(TestCase):
coreapi.Field('ordering', required=False, location='query') coreapi.Field('ordering', required=False, location='query')
] ]
), ),
'custom_list_action': coreapi.Link(
url='/example/custom_list_action/',
action='get'
),
'retrieve': coreapi.Link( 'retrieve': coreapi.Link(
url='/example/{pk}/', url='/example/{pk}/',
action='get', action='get',
@ -129,7 +141,7 @@ class TestRouterGeneratedSchema(TestCase):
action='post', action='post',
encoding='application/json', encoding='application/json',
fields=[ fields=[
coreapi.Field('a', required=True, location='form'), coreapi.Field('a', required=True, location='form', description='A field description'),
coreapi.Field('b', required=False, location='form') coreapi.Field('b', required=False, location='form')
] ]
), ),
@ -160,13 +172,17 @@ class TestRouterGeneratedSchema(TestCase):
coreapi.Field('d', required=False, location='form'), coreapi.Field('d', required=False, location='form'),
] ]
), ),
'custom_list_action': coreapi.Link(
url='/example/custom_list_action/',
action='get'
),
'update': coreapi.Link( 'update': coreapi.Link(
url='/example/{pk}/', url='/example/{pk}/',
action='put', action='put',
encoding='application/json', encoding='application/json',
fields=[ fields=[
coreapi.Field('pk', required=True, location='path'), coreapi.Field('pk', required=True, location='path'),
coreapi.Field('a', required=True, location='form'), coreapi.Field('a', required=True, location='form', description='A field description'),
coreapi.Field('b', required=False, location='form') coreapi.Field('b', required=False, location='form')
] ]
), ),
@ -176,7 +192,7 @@ class TestRouterGeneratedSchema(TestCase):
encoding='application/json', encoding='application/json',
fields=[ fields=[
coreapi.Field('pk', required=True, location='path'), coreapi.Field('pk', required=True, location='path'),
coreapi.Field('a', required=False, location='form'), coreapi.Field('a', required=False, location='form', description='A field description'),
coreapi.Field('b', required=False, location='form') coreapi.Field('b', required=False, location='form')
] ]
), ),

View File

@ -44,14 +44,34 @@ class ReverseView(APIView):
return Response({'url': reverse('another', request=request)}) return Response({'url': reverse('another', request=request)})
class RequestInvalidVersionView(APIView): class AllowedVersionsView(RequestVersionView):
def determine_version(self, request, *args, **kwargs): def determine_version(self, request, *args, **kwargs):
scheme = self.versioning_class() scheme = self.versioning_class()
scheme.allowed_versions = ('v1', 'v2') scheme.allowed_versions = ('v1', 'v2')
return (scheme.determine_version(request, *args, **kwargs), scheme) return (scheme.determine_version(request, *args, **kwargs), scheme)
def get(self, request, *args, **kwargs):
return Response({'version': request.version}) class AllowedAndDefaultVersionsView(RequestVersionView):
def determine_version(self, request, *args, **kwargs):
scheme = self.versioning_class()
scheme.allowed_versions = ('v1', 'v2')
scheme.default_version = 'v2'
return (scheme.determine_version(request, *args, **kwargs), scheme)
class AllowedWithNoneVersionsView(RequestVersionView):
def determine_version(self, request, *args, **kwargs):
scheme = self.versioning_class()
scheme.allowed_versions = ('v1', 'v2', None)
return (scheme.determine_version(request, *args, **kwargs), scheme)
class AllowedWithNoneAndDefaultVersionsView(RequestVersionView):
def determine_version(self, request, *args, **kwargs):
scheme = self.versioning_class()
scheme.allowed_versions = ('v1', 'v2', None)
scheme.default_version = 'v2'
return (scheme.determine_version(request, *args, **kwargs), scheme)
factory = APIRequestFactory() factory = APIRequestFactory()
@ -219,7 +239,7 @@ class TestURLReversing(URLPatternsTestCase):
class TestInvalidVersion: class TestInvalidVersion:
def test_invalid_query_param_versioning(self): def test_invalid_query_param_versioning(self):
scheme = versioning.QueryParameterVersioning scheme = versioning.QueryParameterVersioning
view = RequestInvalidVersionView.as_view(versioning_class=scheme) view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/?version=v3') request = factory.get('/endpoint/?version=v3')
response = view(request) response = view(request)
@ -228,7 +248,7 @@ class TestInvalidVersion:
@override_settings(ALLOWED_HOSTS=['*']) @override_settings(ALLOWED_HOSTS=['*'])
def test_invalid_host_name_versioning(self): def test_invalid_host_name_versioning(self):
scheme = versioning.HostNameVersioning scheme = versioning.HostNameVersioning
view = RequestInvalidVersionView.as_view(versioning_class=scheme) view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/', HTTP_HOST='v3.example.org') request = factory.get('/endpoint/', HTTP_HOST='v3.example.org')
response = view(request) response = view(request)
@ -236,7 +256,7 @@ class TestInvalidVersion:
def test_invalid_accept_header_versioning(self): def test_invalid_accept_header_versioning(self):
scheme = versioning.AcceptHeaderVersioning scheme = versioning.AcceptHeaderVersioning
view = RequestInvalidVersionView.as_view(versioning_class=scheme) view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3') request = factory.get('/endpoint/', HTTP_ACCEPT='application/json; version=v3')
response = view(request) response = view(request)
@ -244,7 +264,7 @@ class TestInvalidVersion:
def test_invalid_url_path_versioning(self): def test_invalid_url_path_versioning(self):
scheme = versioning.URLPathVersioning scheme = versioning.URLPathVersioning
view = RequestInvalidVersionView.as_view(versioning_class=scheme) view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/v3/endpoint/') request = factory.get('/v3/endpoint/')
response = view(request, version='v3') response = view(request, version='v3')
@ -255,7 +275,7 @@ class TestInvalidVersion:
namespace = 'v3' namespace = 'v3'
scheme = versioning.NamespaceVersioning scheme = versioning.NamespaceVersioning
view = RequestInvalidVersionView.as_view(versioning_class=scheme) view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/v3/endpoint/') request = factory.get('/v3/endpoint/')
request.resolver_match = FakeResolverMatch request.resolver_match = FakeResolverMatch
@ -263,6 +283,52 @@ class TestInvalidVersion:
assert response.status_code == status.HTTP_404_NOT_FOUND assert response.status_code == status.HTTP_404_NOT_FOUND
class TestAllowedAndDefaultVersion:
def test_missing_without_default(self):
scheme = versioning.AcceptHeaderVersioning
view = AllowedVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
response = view(request)
assert response.status_code == status.HTTP_406_NOT_ACCEPTABLE
def test_missing_with_default(self):
scheme = versioning.AcceptHeaderVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v2'}
def test_with_default(self):
scheme = versioning.AcceptHeaderVersioning
view = AllowedAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/',
HTTP_ACCEPT='application/json; version=v2')
response = view(request)
assert response.status_code == status.HTTP_200_OK
def test_missing_without_default_but_none_allowed(self):
scheme = versioning.AcceptHeaderVersioning
view = AllowedWithNoneVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': None}
def test_missing_with_default_and_none_allowed(self):
scheme = versioning.AcceptHeaderVersioning
view = AllowedWithNoneAndDefaultVersionsView.as_view(versioning_class=scheme)
request = factory.get('/endpoint/', HTTP_ACCEPT='application/json')
response = view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'version': 'v2'}
class TestHyperlinkedRelatedField(URLPatternsTestCase): class TestHyperlinkedRelatedField(URLPatternsTestCase):
included = [ included = [
url(r'^namespaced/(?P<pk>\d+)/$', dummy_pk_view, name='namespaced'), url(r'^namespaced/(?P<pk>\d+)/$', dummy_pk_view, name='namespaced'),