diff --git a/docs/api-guide/generic-views.md b/docs/api-guide/generic-views.md
index b1c4e65ad..49be0cae8 100755
--- a/docs/api-guide/generic-views.md
+++ b/docs/api-guide/generic-views.md
@@ -19,8 +19,8 @@ Typically when using the generic views, you'll override the view, and set severa
from django.contrib.auth.models import User
from myapp.serializers import UserSerializer
- from rest_framework import generics
- from rest_framework.permissions import IsAdminUser
+ from rest_framework import generics
+ from rest_framework.permissions import IsAdminUser
class UserList(generics.ListCreateAPIView):
queryset = User.objects.all()
@@ -212,8 +212,6 @@ Provides a `.list(request, *args, **kwargs)` method, that implements listing a q
If the queryset is populated, this returns a `200 OK` response, with a serialized representation of the queryset as the body of the response. The response data may optionally be paginated.
-If the queryset is empty this returns a `200 OK` response, unless the `.allow_empty` attribute on the view is set to `False`, in which case it will return a `404 Not Found`.
-
## CreateModelMixin
Provides a `.create(request, *args, **kwargs)` method, that implements creating and saving a new model instance.
diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py
index 449ba0a29..d28d6e22a 100644
--- a/rest_framework/decorators.py
+++ b/rest_framework/decorators.py
@@ -10,7 +10,6 @@ from __future__ import unicode_literals
from django.utils import six
from rest_framework.views import APIView
import types
-import warnings
def api_view(http_method_names):
@@ -130,37 +129,3 @@ def list_route(methods=['get'], **kwargs):
func.kwargs = kwargs
return func
return decorator
-
-
-# These are now pending deprecation, in favor of `detail_route` and `list_route`.
-
-def link(**kwargs):
- """
- Used to mark a method on a ViewSet that should be routed for detail GET requests.
- """
- msg = 'link is pending deprecation. Use detail_route instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
-
- def decorator(func):
- func.bind_to_methods = ['get']
- func.detail = True
- func.kwargs = kwargs
- return func
-
- return decorator
-
-
-def action(methods=['post'], **kwargs):
- """
- Used to mark a method on a ViewSet that should be routed for detail POST requests.
- """
- msg = 'action is pending deprecation. Use detail_route instead.'
- warnings.warn(msg, PendingDeprecationWarning, stacklevel=2)
-
- def decorator(func):
- func.bind_to_methods = methods
- func.detail = True
- func.kwargs = kwargs
- return func
-
- return decorator
diff --git a/rest_framework/fields.py b/rest_framework/fields.py
index 9d707c9b5..a83bf94c4 100644
--- a/rest_framework/fields.py
+++ b/rest_framework/fields.py
@@ -1,1038 +1,308 @@
-"""
-Serializer fields perform validation on incoming data.
-
-They are very similar to Django's form fields.
-"""
-from __future__ import unicode_literals
-
-import copy
-import datetime
-import inspect
-import re
-import warnings
-from decimal import Decimal, DecimalException
-from django import forms
-from django.core import validators
-from django.core.exceptions import ValidationError
-from django.conf import settings
-from django.db.models.fields import BLANK_CHOICE_DASH
-from django.http import QueryDict
-from django.forms import widgets
-from django.utils import six, timezone
-from django.utils.encoding import is_protected_type
-from django.utils.translation import ugettext_lazy as _
-from django.utils.datastructures import SortedDict
-from django.utils.dateparse import parse_date, parse_datetime, parse_time
-from rest_framework import ISO_8601
-from rest_framework.compat import (
- BytesIO, smart_text,
- force_text, is_non_str_iterable
-)
-from rest_framework.settings import api_settings
+from rest_framework.utils import html
-def is_simple_callable(obj):
+class empty:
"""
- True if the object is a callable that takes no arguments.
+ This class is used to represent no data being provided for a given input
+ or output value.
+
+ It is required because `None` may be a valid input or output value.
"""
- function = inspect.isfunction(obj)
- method = inspect.ismethod(obj)
-
- if not (function or method):
- return False
-
- args, _, _, defaults = inspect.getargspec(obj)
- len_args = len(args) if function else len(args) - 1
- len_defaults = len(defaults) if defaults else 0
- return len_args <= len_defaults
+ pass
-def get_component(obj, attr_name):
+def get_attribute(instance, attrs):
"""
- Given an object, and an attribute name,
- return that attribute on the object.
+ Similar to Python's built in `getattr(instance, attr)`,
+ but takes a list of nested attributes, instead of a single attribute.
"""
- if isinstance(obj, dict):
- val = obj.get(attr_name)
- else:
- val = getattr(obj, attr_name)
-
- if is_simple_callable(val):
- return val()
- return val
+ for attr in attrs:
+ instance = getattr(instance, attr)
+ return instance
-def readable_datetime_formats(formats):
- format = ', '.join(formats).replace(
- ISO_8601,
- 'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
- )
- return humanize_strptime(format)
-
-
-def readable_date_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
- return humanize_strptime(format)
-
-
-def readable_time_formats(formats):
- format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
- return humanize_strptime(format)
-
-
-def humanize_strptime(format_string):
- # Note that we're missing some of the locale specific mappings that
- # don't really make sense.
- mapping = {
- "%Y": "YYYY",
- "%y": "YY",
- "%m": "MM",
- "%b": "[Jan-Dec]",
- "%B": "[January-December]",
- "%d": "DD",
- "%H": "hh",
- "%I": "hh", # Requires '%p' to differentiate from '%H'.
- "%M": "mm",
- "%S": "ss",
- "%f": "uuuuuu",
- "%a": "[Mon-Sun]",
- "%A": "[Monday-Sunday]",
- "%p": "[AM|PM]",
- "%z": "[+HHMM|-HHMM]"
- }
- for key, val in mapping.items():
- format_string = format_string.replace(key, val)
- return format_string
-
-
-def strip_multiple_choice_msg(help_text):
+def set_value(dictionary, keys, value):
"""
- Remove the 'Hold down "control" ...' message that is Django enforces in
- select multiple fields on ModelForms. (Required for 1.5 and earlier)
+ Similar to Python's built in `dictionary[key] = value`,
+ but takes a list of nested keys instead of a single key.
- See https://code.djangoproject.com/ticket/9321
+ set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
+ set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
+ set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
"""
- multiple_choice_msg = _(' Hold down "Control", or "Command" on a Mac, to select more than one.')
- multiple_choice_msg = force_text(multiple_choice_msg)
+ if not keys:
+ dictionary.update(value)
+ return
- return help_text.replace(multiple_choice_msg, '')
+ for key in keys[:-1]:
+ if key not in dictionary:
+ dictionary[key] = {}
+ dictionary = dictionary[key]
+
+ dictionary[keys[-1]] = value
+
+
+class ValidationError(Exception):
+ pass
+
+
+class SkipField(Exception):
+ pass
class Field(object):
- read_only = True
- creation_counter = 0
- empty = ''
- type_name = None
- partial = False
- use_files = False
- form_field_class = forms.CharField
- type_label = 'field'
- widget = None
+ _creation_counter = 0
- def __init__(self, source=None, label=None, help_text=None):
- self.parent = None
-
- self.creation_counter = Field.creation_counter
- Field.creation_counter += 1
-
- self.source = source
-
- if label is not None:
- self.label = smart_text(label)
- else:
- self.label = None
-
- if help_text is not None:
- self.help_text = strip_multiple_choice_msg(smart_text(help_text))
- else:
- self.help_text = None
-
- self._errors = []
- self._value = None
- self._name = None
-
- @property
- def errors(self):
- return self._errors
-
- def widget_html(self):
- if not self.widget:
- return ''
-
- attrs = {}
- if 'id' not in self.widget.attrs:
- attrs['id'] = self._name
-
- return self.widget.render(self._name, self._value, attrs=attrs)
-
- def label_tag(self):
- return '' % (self._name, self.label)
-
- def initialize(self, parent, field_name):
- """
- Called to set up a field prior to field_to_native or field_from_native.
-
- parent - The parent serializer.
- field_name - The name of the field being initialized.
- """
- self.parent = parent
- self.root = parent.root or parent
- self.context = self.root.context
- self.partial = self.root.partial
- if self.partial:
- self.required = False
-
- def field_from_native(self, data, files, field_name, into):
- """
- Given a dictionary and a field name, updates the dictionary `into`,
- with the field and it's deserialized value.
- """
- return
-
- def field_to_native(self, obj, field_name):
- """
- Given an object and a field name, returns the value that should be
- serialized for that field.
- """
- if obj is None:
- return self.empty
-
- if self.source == '*':
- return self.to_native(obj)
-
- source = self.source or field_name
- value = obj
-
- for component in source.split('.'):
- value = get_component(value, component)
- if value is None:
- break
-
- return self.to_native(value)
-
- def to_native(self, value):
- """
- Converts the field's value into it's simple representation.
- """
- if is_simple_callable(value):
- value = value()
-
- if is_protected_type(value):
- return value
- elif (is_non_str_iterable(value) and
- not isinstance(value, (dict, six.string_types))):
- return [self.to_native(item) for item in value]
- elif isinstance(value, dict):
- # Make sure we preserve field ordering, if it exists
- ret = SortedDict()
- for key, val in value.items():
- ret[key] = self.to_native(val)
- return ret
- return force_text(value)
-
- def attributes(self):
- """
- Returns a dictionary of attributes to be used when serializing to xml.
- """
- if self.type_name:
- return {'type': self.type_name}
- return {}
-
- def metadata(self):
- metadata = SortedDict()
- metadata['type'] = self.type_label
- metadata['required'] = getattr(self, 'required', False)
- optional_attrs = ['read_only', 'label', 'help_text',
- 'min_length', 'max_length']
- for attr in optional_attrs:
- value = getattr(self, attr, None)
- if value is not None and value != '':
- metadata[attr] = force_text(value, strings_only=True)
- return metadata
-
-
-class WritableField(Field):
- """
- Base for read/write fields.
- """
- write_only = False
- default_validators = []
- default_error_messages = {
- 'required': _('This field is required.'),
- 'invalid': _('Invalid value.'),
+ MESSAGES = {
+ 'required': 'This field is required.'
}
- widget = widgets.TextInput
- default = None
- def __init__(self, source=None, label=None, help_text=None,
- read_only=False, write_only=False, required=None,
- validators=[], error_messages=None, widget=None,
- default=None, blank=None):
+ _NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
+ _NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
+ _NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
+ _NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
+ _MISSING_ERROR_MESSAGE = (
+ 'ValidationError raised by `{class_name}`, but error key `{key}` does '
+ 'not exist in the `MESSAGES` dictionary.'
+ )
- super(WritableField, self).__init__(source=source, label=label, help_text=help_text)
+ def __init__(self, read_only=False, write_only=False,
+ required=None, default=empty, initial=None, source=None,
+ label=None, style=None):
+ self._creation_counter = Field._creation_counter
+ Field._creation_counter += 1
+
+ # If `required` is unset, then use `True` unless a default is provided.
+ if required is None:
+ required = default is empty and not read_only
+
+ # Some combinations of keyword arguments do not make sense.
+ assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY
+ assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED
+ assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT
+ assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT
self.read_only = read_only
self.write_only = write_only
+ self.required = required
+ self.default = default
+ self.source = source
+ self.initial = initial
+ self.label = label
+ self.style = {} if style is None else style
- assert not (read_only and write_only), "Cannot set read_only=True and write_only=True"
+ def bind(self, field_name, parent, root):
+ """
+ Setup the context for the field instance.
+ """
+ self.field_name = field_name
+ self.parent = parent
+ self.root = root
- if required is None:
- self.required = not(read_only)
+ # `self.label` should deafult to being based on the field name.
+ if self.label is None:
+ self.label = self.field_name.replace('_', ' ').capitalize()
+
+ # self.source should default to being the same as the field name.
+ if self.source is None:
+ self.source = field_name
+
+ # self.source_attrs is a list of attributes that need to be looked up
+ # when serializing the instance, or populating the validated data.
+ if self.source == '*':
+ self.source_attrs = []
else:
- assert not (read_only and required), "Cannot set required=True and read_only=True"
- self.required = required
+ self.source_attrs = self.source.split('.')
- messages = {}
- for c in reversed(self.__class__.__mro__):
- messages.update(getattr(c, 'default_error_messages', {}))
- messages.update(error_messages or {})
- self.error_messages = messages
+ def get_initial(self):
+ """
+ Return a value to use when the field is being returned as a primative
+ value, without any object instance.
+ """
+ return self.initial
- self.validators = self.default_validators + validators
- self.default = default if default is not None else self.default
+ def get_value(self, dictionary):
+ """
+ Given the *incoming* primative data, return the value for this field
+ that should be validated and transformed to a native value.
+ """
+ return dictionary.get(self.field_name, empty)
- # Widgets are only used for HTML forms.
- widget = widget or self.widget
- if isinstance(widget, type):
- widget = widget()
- self.widget = widget
+ def get_attribute(self, instance):
+ """
+ Given the *outgoing* object instance, return the value for this field
+ that should be returned as a primative value.
+ """
+ return get_attribute(instance, self.source_attrs)
- def __deepcopy__(self, memo):
- result = copy.copy(self)
- memo[id(self)] = result
- result.validators = self.validators[:]
- return result
+ def get_default(self):
+ """
+ Return the default value to use when validating data if no input
+ is provided for this field.
- def get_default_value(self):
- if is_simple_callable(self.default):
- return self.default()
+ If a default has not been set for this field then this will simply
+ return `empty`, indicating that no value should be set in the
+ validated data for this field.
+ """
+ if self.default is empty:
+ raise SkipField()
return self.default
- def validate(self, value):
- if value in validators.EMPTY_VALUES and self.required:
- raise ValidationError(self.error_messages['required'])
-
- def run_validators(self, value):
- if value in validators.EMPTY_VALUES:
- return
- errors = []
- for v in self.validators:
- try:
- v(value)
- except ValidationError as e:
- if hasattr(e, 'code') and e.code in self.error_messages:
- message = self.error_messages[e.code]
- if e.params:
- message = message % e.params
- errors.append(message)
- else:
- errors.extend(e.messages)
- if errors:
- raise ValidationError(errors)
-
- def field_to_native(self, obj, field_name):
- if self.write_only:
- return None
- return super(WritableField, self).field_to_native(obj, field_name)
-
- def field_from_native(self, data, files, field_name, into):
+ def validate(self, data=empty):
"""
- Given a dictionary and a field name, updates the dictionary `into`,
- with the field and it's deserialized value.
+ Validate a simple representation and return the internal value.
+
+ The provided data may be `empty` if no representation was included.
+ May return `empty` if the field should not be included in the
+ validated data.
"""
- if self.read_only:
- return
+ if data is empty:
+ if self.required:
+ self.fail('required')
+ return self.get_default()
- try:
- data = data or {}
- if self.use_files:
- files = files or {}
- try:
- native = files[field_name]
- except KeyError:
- native = data[field_name]
- else:
- native = data[field_name]
- except KeyError:
- if self.default is not None and not self.partial:
- # Note: partial updates shouldn't set defaults
- native = self.get_default_value()
- else:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
+ return self.to_native(data)
- value = self.from_native(native)
- if self.source == '*':
- if value:
- into.update(value)
- else:
- self.validate(value)
- self.run_validators(value)
- into[self.source or field_name] = value
-
- def from_native(self, value):
+ def to_native(self, data):
"""
- Reverts a simple representation back to the field's value.
+ Transform the *incoming* primative data into a native value.
+ """
+ return data
+
+ def to_primative(self, value):
+ """
+ Transform the *outgoing* native value into primative data.
"""
return value
-
-class ModelField(WritableField):
- """
- A generic field that can be used against an arbitrary model field.
- """
- def __init__(self, *args, **kwargs):
+ def fail(self, key, **kwargs):
+ """
+ A helper method that simply raises a validation error.
+ """
try:
- self.model_field = kwargs.pop('model_field')
+ raise ValidationError(self.MESSAGES[key].format(**kwargs))
except KeyError:
- raise ValueError("ModelField requires 'model_field' kwarg")
+ class_name = self.__class__.__name__
+ msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
+ raise AssertionError(msg)
- self.min_length = kwargs.pop('min_length',
- getattr(self.model_field, 'min_length', None))
- self.max_length = kwargs.pop('max_length',
- getattr(self.model_field, 'max_length', None))
- self.min_value = kwargs.pop('min_value',
- getattr(self.model_field, 'min_value', None))
- self.max_value = kwargs.pop('max_value',
- getattr(self.model_field, 'max_value', None))
- super(ModelField, self).__init__(*args, **kwargs)
+class BooleanField(Field):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'invalid_value': '`{input}` is not a valid boolean.'
+ }
+ TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
+ FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
- if self.min_length is not None:
- self.validators.append(validators.MinLengthValidator(self.min_length))
- if self.max_length is not None:
- self.validators.append(validators.MaxLengthValidator(self.max_length))
- if self.min_value is not None:
- self.validators.append(validators.MinValueValidator(self.min_value))
- if self.max_value is not None:
- self.validators.append(validators.MaxValueValidator(self.max_value))
+ def get_value(self, dictionary):
+ if html.is_html_input(dictionary):
+ # HTML forms do not send a `False` value on an empty checkbox,
+ # so we override the default empty value to be False.
+ return dictionary.get(self.field_name, False)
+ return dictionary.get(self.field_name, empty)
- def from_native(self, value):
- rel = getattr(self.model_field, "rel", None)
- if rel is not None:
- return rel.to._meta.get_field(rel.field_name).to_python(value)
+ def to_native(self, data):
+ if data in self.TRUE_VALUES:
+ return True
+ elif data in self.FALSE_VALUES:
+ return False
+ self.fail('invalid_value', input=data)
+
+
+class CharField(Field):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'blank': 'This field may not be blank.'
+ }
+
+ def __init__(self, *args, **kwargs):
+ self.allow_blank = kwargs.pop('allow_blank', False)
+ super(CharField, self).__init__(*args, **kwargs)
+
+ def to_native(self, data):
+ if data == '' and not self.allow_blank:
+ self.fail('blank')
+ return str(data)
+
+
+class ChoiceField(Field):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'invalid_choice': '`{input}` is not a valid choice.'
+ }
+ coerce_to_type = str
+
+ def __init__(self, *args, **kwargs):
+ choices = kwargs.pop('choices')
+
+ assert choices, '`choices` argument is required and may not be empty'
+
+ # Allow either single or paired choices style:
+ # choices = [1, 2, 3]
+ # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
+ pairs = [
+ isinstance(item, (list, tuple)) and len(item) == 2
+ for item in choices
+ ]
+ if all(pairs):
+ self.choices = {key: val for key, val in choices}
else:
- return self.model_field.to_python(value)
+ self.choices = {item: item for item in choices}
- def field_to_native(self, obj, field_name):
- value = self.model_field._get_val_from_obj(obj)
- if is_protected_type(value):
- return value
- return self.model_field.value_to_string(obj)
-
- def attributes(self):
- return {
- "type": self.model_field.get_internal_type()
+ # Map the string representation of choices to the underlying value.
+ # Allows us to deal with eg. integer choices while supporting either
+ # integer or string input, but still get the correct datatype out.
+ self.choice_strings_to_values = {
+ str(key): key for key in self.choices.keys()
}
-
-# Typed Fields
-
-class BooleanField(WritableField):
- type_name = 'BooleanField'
- type_label = 'boolean'
- form_field_class = forms.BooleanField
- widget = widgets.CheckboxInput
- default_error_messages = {
- 'invalid': _("'%s' value must be either True or False."),
- }
- empty = False
-
- def field_from_native(self, data, files, field_name, into):
- # HTML checkboxes do not explicitly represent unchecked as `False`
- # we deal with that here...
- if isinstance(data, QueryDict) and self.default is None:
- self.default = False
-
- return super(BooleanField, self).field_from_native(
- data, files, field_name, into
- )
-
- def from_native(self, value):
- if value in ('true', 't', 'True', '1'):
- return True
- if value in ('false', 'f', 'False', '0'):
- return False
- return bool(value)
-
-
-class CharField(WritableField):
- type_name = 'CharField'
- type_label = 'string'
- form_field_class = forms.CharField
-
- def __init__(self, max_length=None, min_length=None, allow_none=False, *args, **kwargs):
- self.max_length, self.min_length = max_length, min_length
- self.allow_none = allow_none
- super(CharField, self).__init__(*args, **kwargs)
- if min_length is not None:
- self.validators.append(validators.MinLengthValidator(min_length))
- if max_length is not None:
- self.validators.append(validators.MaxLengthValidator(max_length))
-
- def from_native(self, value):
- if isinstance(value, six.string_types):
- return value
-
- if value is None and not self.allow_none:
- return ''
-
- return smart_text(value)
-
-
-class URLField(CharField):
- type_name = 'URLField'
- type_label = 'url'
-
- def __init__(self, **kwargs):
- if 'validators' not in kwargs:
- kwargs['validators'] = [validators.URLValidator()]
- super(URLField, self).__init__(**kwargs)
-
-
-class SlugField(CharField):
- type_name = 'SlugField'
- type_label = 'slug'
- form_field_class = forms.SlugField
-
- default_error_messages = {
- 'invalid': _("Enter a valid 'slug' consisting of letters, numbers,"
- " underscores or hyphens."),
- }
- default_validators = [validators.validate_slug]
-
- def __init__(self, *args, **kwargs):
- super(SlugField, self).__init__(*args, **kwargs)
-
-
-class ChoiceField(WritableField):
- type_name = 'ChoiceField'
- type_label = 'choice'
- form_field_class = forms.ChoiceField
- widget = widgets.Select
- default_error_messages = {
- 'invalid_choice': _('Select a valid choice. %(value)s is not one of '
- 'the available choices.'),
- }
-
- def __init__(self, choices=(), blank_display_value=None, *args, **kwargs):
- self.empty = kwargs.pop('empty', '')
super(ChoiceField, self).__init__(*args, **kwargs)
- self.choices = choices
- if not self.required:
- if blank_display_value is None:
- blank_choice = BLANK_CHOICE_DASH
- else:
- blank_choice = [('', blank_display_value)]
- self.choices = blank_choice + self.choices
-
- def _get_choices(self):
- return self._choices
-
- def _set_choices(self, value):
- # Setting choices also sets the choices on the widget.
- # choices can be any iterable, but we call list() on it because
- # it will be consumed more than once.
- self._choices = self.widget.choices = list(value)
-
- choices = property(_get_choices, _set_choices)
-
- def metadata(self):
- data = super(ChoiceField, self).metadata()
- data['choices'] = [{'value': v, 'display_name': n} for v, n in self.choices]
- return data
-
- def validate(self, value):
- """
- Validates that the input is in self.choices.
- """
- super(ChoiceField, self).validate(value)
- if value and not self.valid_value(value):
- raise ValidationError(self.error_messages['invalid_choice'] % {'value': value})
-
- def valid_value(self, value):
- """
- Check to see if the provided value is a valid choice.
- """
- for k, v in self.choices:
- if isinstance(v, (list, tuple)):
- # This is an optgroup, so look inside the group for options
- for k2, v2 in v:
- if value == smart_text(k2):
- return True
- else:
- if value == smart_text(k) or value == k:
- return True
- return False
-
- def from_native(self, value):
- value = super(ChoiceField, self).from_native(value)
- if value == self.empty or value in validators.EMPTY_VALUES:
- return self.empty
- return value
-
-
-class EmailField(CharField):
- type_name = 'EmailField'
- type_label = 'email'
- form_field_class = forms.EmailField
-
- default_error_messages = {
- 'invalid': _('Enter a valid email address.'),
- }
- default_validators = [validators.validate_email]
-
- def from_native(self, value):
- ret = super(EmailField, self).from_native(value)
- if ret is None:
- return None
- return ret.strip()
-
-
-class RegexField(CharField):
- type_name = 'RegexField'
- type_label = 'regex'
- form_field_class = forms.RegexField
-
- def __init__(self, regex, max_length=None, min_length=None, *args, **kwargs):
- super(RegexField, self).__init__(max_length, min_length, *args, **kwargs)
- self.regex = regex
-
- def _get_regex(self):
- return self._regex
-
- def _set_regex(self, regex):
- if isinstance(regex, six.string_types):
- regex = re.compile(regex)
- self._regex = regex
- if hasattr(self, '_regex_validator') and self._regex_validator in self.validators:
- self.validators.remove(self._regex_validator)
- self._regex_validator = validators.RegexValidator(regex=regex)
- self.validators.append(self._regex_validator)
-
- regex = property(_get_regex, _set_regex)
-
-
-class DateField(WritableField):
- type_name = 'DateField'
- type_label = 'date'
- widget = widgets.DateInput
- form_field_class = forms.DateField
-
- default_error_messages = {
- 'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
- }
- empty = None
- input_formats = api_settings.DATE_INPUT_FORMATS
- format = api_settings.DATE_FORMAT
-
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
- self.format = format if format is not None else self.format
- super(DateField, self).__init__(*args, **kwargs)
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- if isinstance(value, datetime.datetime):
- if timezone and settings.USE_TZ and timezone.is_aware(value):
- # Convert aware datetimes to the default time zone
- # before casting them to dates (#17742).
- default_timezone = timezone.get_default_timezone()
- value = timezone.make_naive(value, default_timezone)
- return value.date()
- if isinstance(value, datetime.date):
- return value
-
- for format in self.input_formats:
- if format.lower() == ISO_8601:
- try:
- parsed = parse_date(value)
- except (ValueError, TypeError):
- pass
- else:
- if parsed is not None:
- return parsed
- else:
- try:
- parsed = datetime.datetime.strptime(value, format)
- except (ValueError, TypeError):
- pass
- else:
- return parsed.date()
-
- msg = self.error_messages['invalid'] % readable_date_formats(self.input_formats)
- raise ValidationError(msg)
-
- def to_native(self, value):
- if value is None or self.format is None:
- return value
-
- if isinstance(value, datetime.datetime):
- value = value.date()
-
- if self.format.lower() == ISO_8601:
- return value.isoformat()
- return value.strftime(self.format)
-
-
-class DateTimeField(WritableField):
- type_name = 'DateTimeField'
- type_label = 'datetime'
- widget = widgets.DateTimeInput
- form_field_class = forms.DateTimeField
-
- default_error_messages = {
- 'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
- }
- empty = None
- input_formats = api_settings.DATETIME_INPUT_FORMATS
- format = api_settings.DATETIME_FORMAT
-
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
- self.format = format if format is not None else self.format
- super(DateTimeField, self).__init__(*args, **kwargs)
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- if isinstance(value, datetime.datetime):
- return value
- if isinstance(value, datetime.date):
- value = datetime.datetime(value.year, value.month, value.day)
- if settings.USE_TZ:
- # For backwards compatibility, interpret naive datetimes in
- # local time. This won't work during DST change, but we can't
- # do much about it, so we let the exceptions percolate up the
- # call stack.
- warnings.warn("DateTimeField received a naive datetime (%s)"
- " while time zone support is active." % value,
- RuntimeWarning)
- default_timezone = timezone.get_default_timezone()
- value = timezone.make_aware(value, default_timezone)
- return value
-
- for format in self.input_formats:
- if format.lower() == ISO_8601:
- try:
- parsed = parse_datetime(value)
- except (ValueError, TypeError):
- pass
- else:
- if parsed is not None:
- return parsed
- else:
- try:
- parsed = datetime.datetime.strptime(value, format)
- except (ValueError, TypeError):
- pass
- else:
- return parsed
-
- msg = self.error_messages['invalid'] % readable_datetime_formats(self.input_formats)
- raise ValidationError(msg)
-
- def to_native(self, value):
- if value is None or self.format is None:
- return value
-
- if self.format.lower() == ISO_8601:
- ret = value.isoformat()
- if ret.endswith('+00:00'):
- ret = ret[:-6] + 'Z'
- return ret
- return value.strftime(self.format)
-
-
-class TimeField(WritableField):
- type_name = 'TimeField'
- type_label = 'time'
- widget = widgets.TimeInput
- form_field_class = forms.TimeField
-
- default_error_messages = {
- 'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
- }
- empty = None
- input_formats = api_settings.TIME_INPUT_FORMATS
- format = api_settings.TIME_FORMAT
-
- def __init__(self, input_formats=None, format=None, *args, **kwargs):
- self.input_formats = input_formats if input_formats is not None else self.input_formats
- self.format = format if format is not None else self.format
- super(TimeField, self).__init__(*args, **kwargs)
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- if isinstance(value, datetime.time):
- return value
-
- for format in self.input_formats:
- if format.lower() == ISO_8601:
- try:
- parsed = parse_time(value)
- except (ValueError, TypeError):
- pass
- else:
- if parsed is not None:
- return parsed
- else:
- try:
- parsed = datetime.datetime.strptime(value, format)
- except (ValueError, TypeError):
- pass
- else:
- return parsed.time()
-
- msg = self.error_messages['invalid'] % readable_time_formats(self.input_formats)
- raise ValidationError(msg)
-
- def to_native(self, value):
- if value is None or self.format is None:
- return value
-
- if isinstance(value, datetime.datetime):
- value = value.time()
-
- if self.format.lower() == ISO_8601:
- return value.isoformat()
- return value.strftime(self.format)
-
-
-class IntegerField(WritableField):
- type_name = 'IntegerField'
- type_label = 'integer'
- form_field_class = forms.IntegerField
- empty = 0
-
- default_error_messages = {
- 'invalid': _('Enter a whole number.'),
- 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
- 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
- }
-
- def __init__(self, max_value=None, min_value=None, *args, **kwargs):
- self.max_value, self.min_value = max_value, min_value
- super(IntegerField, self).__init__(*args, **kwargs)
-
- if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
- if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
+ def to_native(self, data):
try:
- value = int(str(value))
+ return self.choice_strings_to_values[str(data)]
+ except KeyError:
+ self.fail('invalid_choice', input=data)
+
+
+class MultipleChoiceField(ChoiceField):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'invalid_choice': '`{input}` is not a valid choice.',
+ 'not_a_list': 'Expected a list of items but got type `{input_type}`'
+ }
+
+ def to_native(self, data):
+ if not hasattr(data, '__iter__'):
+ self.fail('not_a_list', input_type=type(data).__name__)
+ return set([
+ super(MultipleChoiceField, self).to_native(item)
+ for item in data
+ ])
+
+
+class IntegerField(Field):
+ MESSAGES = {
+ 'required': 'This field is required.',
+ 'invalid_integer': 'A valid integer is required.'
+ }
+
+ def to_native(self, data):
+ try:
+ data = int(str(data))
except (ValueError, TypeError):
- raise ValidationError(self.error_messages['invalid'])
- return value
-
-
-class FloatField(WritableField):
- type_name = 'FloatField'
- type_label = 'float'
- form_field_class = forms.FloatField
- empty = 0
-
- default_error_messages = {
- 'invalid': _("'%s' value must be a float."),
- }
-
- def from_native(self, value):
- if value in validators.EMPTY_VALUES:
- return None
-
- try:
- return float(value)
- except (TypeError, ValueError):
- msg = self.error_messages['invalid'] % value
- raise ValidationError(msg)
-
-
-class DecimalField(WritableField):
- type_name = 'DecimalField'
- type_label = 'decimal'
- form_field_class = forms.DecimalField
- empty = Decimal('0')
-
- default_error_messages = {
- 'invalid': _('Enter a number.'),
- 'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
- 'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
- 'max_digits': _('Ensure that there are no more than %s digits in total.'),
- 'max_decimal_places': _('Ensure that there are no more than %s decimal places.'),
- 'max_whole_digits': _('Ensure that there are no more than %s digits before the decimal point.')
- }
-
- def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, *args, **kwargs):
- self.max_value, self.min_value = max_value, min_value
- self.max_digits, self.decimal_places = max_digits, decimal_places
- super(DecimalField, self).__init__(*args, **kwargs)
-
- if max_value is not None:
- self.validators.append(validators.MaxValueValidator(max_value))
- if min_value is not None:
- self.validators.append(validators.MinValueValidator(min_value))
-
- def from_native(self, value):
- """
- Validates that the input is a decimal number. Returns a Decimal
- instance. Returns None for empty values. Ensures that there are no more
- than max_digits in the number, and no more than decimal_places digits
- after the decimal point.
- """
- if value in validators.EMPTY_VALUES:
- return None
- value = smart_text(value).strip()
- try:
- value = Decimal(value)
- except DecimalException:
- raise ValidationError(self.error_messages['invalid'])
- return value
-
- def validate(self, value):
- super(DecimalField, self).validate(value)
- if value in validators.EMPTY_VALUES:
- return
- # Check for NaN, Inf and -Inf values. We can't compare directly for NaN,
- # since it is never equal to itself. However, NaN is the only value that
- # isn't equal to itself, so we can use this to identify NaN
- if value != value or value == Decimal("Inf") or value == Decimal("-Inf"):
- raise ValidationError(self.error_messages['invalid'])
- sign, digittuple, exponent = value.as_tuple()
- decimals = abs(exponent)
- # digittuple doesn't include any leading zeros.
- digits = len(digittuple)
- if decimals > digits:
- # We have leading zeros up to or past the decimal point. Count
- # everything past the decimal point as a digit. We do not count
- # 0 before the decimal point as a digit since that would mean
- # we would not allow max_digits = decimal_places.
- digits = decimals
- whole_digits = digits - decimals
-
- if self.max_digits is not None and digits > self.max_digits:
- raise ValidationError(self.error_messages['max_digits'] % self.max_digits)
- if self.decimal_places is not None and decimals > self.decimal_places:
- raise ValidationError(self.error_messages['max_decimal_places'] % self.decimal_places)
- if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
- raise ValidationError(self.error_messages['max_whole_digits'] % (self.max_digits - self.decimal_places))
- return value
-
-
-class FileField(WritableField):
- use_files = True
- type_name = 'FileField'
- type_label = 'file upload'
- form_field_class = forms.FileField
- widget = widgets.FileInput
-
- default_error_messages = {
- 'invalid': _("No file was submitted. Check the encoding type on the form."),
- 'missing': _("No file was submitted."),
- 'empty': _("The submitted file is empty."),
- 'max_length': _('Ensure this filename has at most %(max)d characters (it has %(length)d).'),
- 'contradiction': _('Please either submit a file or check the clear checkbox, not both.')
- }
-
- def __init__(self, *args, **kwargs):
- self.max_length = kwargs.pop('max_length', None)
- self.allow_empty_file = kwargs.pop('allow_empty_file', False)
- super(FileField, self).__init__(*args, **kwargs)
-
- def from_native(self, data):
- if data in validators.EMPTY_VALUES:
- return None
-
- # UploadedFile objects should have name and size attributes.
- try:
- file_name = data.name
- file_size = data.size
- except AttributeError:
- raise ValidationError(self.error_messages['invalid'])
-
- if self.max_length is not None and len(file_name) > self.max_length:
- error_values = {'max': self.max_length, 'length': len(file_name)}
- raise ValidationError(self.error_messages['max_length'] % error_values)
- if not file_name:
- raise ValidationError(self.error_messages['invalid'])
- if not self.allow_empty_file and not file_size:
- raise ValidationError(self.error_messages['empty'])
-
+ self.fail('invalid_integer')
return data
- def to_native(self, value):
- return value.name
+class MethodField(Field):
+ def __init__(self, **kwargs):
+ kwargs['source'] = '*'
+ kwargs['read_only'] = True
+ super(MethodField, self).__init__(**kwargs)
-class ImageField(FileField):
- use_files = True
- type_name = 'ImageField'
- type_label = 'image upload'
- form_field_class = forms.ImageField
-
- default_error_messages = {
- 'invalid_image': _("Upload a valid image. The file you uploaded was "
- "either not an image or a corrupted image."),
- }
-
- def from_native(self, data):
- """
- Checks that the file-upload field data contains a valid image (GIF, JPG,
- PNG, possibly others -- whatever the Python Imaging Library supports).
- """
- f = super(ImageField, self).from_native(data)
- if f is None:
- return None
-
- from rest_framework.compat import Image
- assert Image is not None, 'Either Pillow or PIL must be installed for ImageField support.'
-
- # We need to get a file object for PIL. We might have a path or we might
- # have to read the data into memory.
- if hasattr(data, 'temporary_file_path'):
- file = data.temporary_file_path()
- else:
- if hasattr(data, 'read'):
- file = BytesIO(data.read())
- else:
- file = BytesIO(data['content'])
-
- try:
- # load() could spot a truncated JPEG, but it loads the entire
- # image in memory, which is a DoS vector. See #3848 and #18520.
- # verify() must be called immediately after the constructor.
- Image.open(file).verify()
- except ImportError:
- # Under PyPy, it is possible to import PIL. However, the underlying
- # _imaging C module isn't available, so an ImportError will be
- # raised. Catch and re-raise.
- raise
- except Exception: # Python Imaging Library doesn't recognize it as an image
- raise ValidationError(self.error_messages['invalid_image'])
- if hasattr(f, 'seek') and callable(f.seek):
- f.seek(0)
- return f
-
-
-class SerializerMethodField(Field):
- """
- A field that gets its value by calling a method on the serializer it's attached to.
- """
-
- def __init__(self, method_name, *args, **kwargs):
- self.method_name = method_name
- super(SerializerMethodField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- value = getattr(self.parent, self.method_name)(obj)
- return self.to_native(value)
+ def to_primative(self, value):
+ attr = 'get_{field_name}'.format(field_name=self.field_name)
+ method = getattr(self.parent, attr)
+ return method(value)
diff --git a/rest_framework/generics.py b/rest_framework/generics.py
index a6f686571..6705cbb2f 100644
--- a/rest_framework/generics.py
+++ b/rest_framework/generics.py
@@ -3,7 +3,7 @@ Generic views that provide commonly needed behaviour.
"""
from __future__ import unicode_literals
-from django.core.exceptions import ImproperlyConfigured, PermissionDenied
+from django.core.exceptions import PermissionDenied
from django.core.paginator import Paginator, InvalidPage
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
@@ -11,7 +11,6 @@ from django.utils.translation import ugettext as _
from rest_framework import views, mixins, exceptions
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
-import warnings
def strict_positive_int(integer_string, cutoff=None):
@@ -51,11 +50,6 @@ class GenericAPIView(views.APIView):
queryset = None
serializer_class = None
- # This shortcut may be used instead of setting either or both
- # of the `queryset`/`serializer_class` attributes, although using
- # the explicit style is generally preferred.
- model = None
-
# If you want to use object lookups other than pk, set this attribute.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
@@ -71,20 +65,10 @@ class GenericAPIView(views.APIView):
# The filter backend classes to use for queryset filtering
filter_backends = api_settings.DEFAULT_FILTER_BACKENDS
- # The following attributes may be subject to change,
+ # The following attribute may be subject to change,
# and should be considered private API.
- model_serializer_class = api_settings.DEFAULT_MODEL_SERIALIZER_CLASS
paginator_class = Paginator
- ######################################
- # These are pending deprecation...
-
- pk_url_kwarg = 'pk'
- slug_url_kwarg = 'slug'
- slug_field = 'slug'
- allow_empty = True
- filter_backend = api_settings.FILTER_BACKEND
-
def get_serializer_context(self):
"""
Extra context provided to the serializer class.
@@ -95,18 +79,16 @@ class GenericAPIView(views.APIView):
'view': self
}
- def get_serializer(self, instance=None, data=None, files=None, many=False,
- partial=False, allow_add_remove=False):
+ def get_serializer(self, instance=None, data=None, many=False, partial=False):
"""
Return the serializer instance that should be used for validating and
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
context = self.get_serializer_context()
- return serializer_class(instance, data=data, files=files,
- many=many, partial=partial,
- allow_add_remove=allow_add_remove,
- context=context)
+ return serializer_class(
+ instance, data=data, many=many, partial=partial, context=context
+ )
def get_pagination_serializer(self, page):
"""
@@ -120,37 +102,16 @@ class GenericAPIView(views.APIView):
context = self.get_serializer_context()
return pagination_serializer_class(instance=page, context=context)
- def paginate_queryset(self, queryset, page_size=None):
+ def paginate_queryset(self, queryset):
"""
Paginate a queryset if required, either returning a page object,
or `None` if pagination is not configured for this view.
"""
- deprecated_style = False
- if page_size is not None:
- warnings.warn('The `page_size` parameter to `paginate_queryset()` '
- 'is deprecated. '
- 'Note that the return style of this method is also '
- 'changed, and will simply return a page object '
- 'when called without a `page_size` argument.',
- DeprecationWarning, stacklevel=2)
- deprecated_style = True
- else:
- # Determine the required page size.
- # If pagination is not configured, simply return None.
- page_size = self.get_paginate_by()
- if not page_size:
- return None
+ page_size = self.get_paginate_by()
+ if not page_size:
+ return None
- if not self.allow_empty:
- warnings.warn(
- 'The `allow_empty` parameter is deprecated. '
- 'To use `allow_empty=False` style behavior, You should override '
- '`get_queryset()` and explicitly raise a 404 on empty querysets.',
- DeprecationWarning, stacklevel=2
- )
-
- paginator = self.paginator_class(queryset, page_size,
- allow_empty_first_page=self.allow_empty)
+ paginator = self.paginator_class(queryset, page_size)
page_kwarg = self.kwargs.get(self.page_kwarg)
page_query_param = self.request.QUERY_PARAMS.get(self.page_kwarg)
page = page_kwarg or page_query_param or 1
@@ -170,8 +131,6 @@ class GenericAPIView(views.APIView):
'message': str(exc)
})
- if deprecated_style:
- return (paginator, page, page.object_list, page.has_other_pages())
return page
def filter_queryset(self, queryset):
@@ -191,29 +150,12 @@ class GenericAPIView(views.APIView):
"""
Returns the list of filter backends that this view requires.
"""
- if self.filter_backends is None:
- filter_backends = []
- else:
- # Note that we are returning a *copy* of the class attribute,
- # so that it is safe for the view to mutate it if needed.
- filter_backends = list(self.filter_backends)
-
- if not filter_backends and self.filter_backend:
- warnings.warn(
- 'The `filter_backend` attribute and `FILTER_BACKEND` setting '
- 'are deprecated in favor of a `filter_backends` '
- 'attribute and `DEFAULT_FILTER_BACKENDS` setting, that take '
- 'a *list* of filter backend classes.',
- DeprecationWarning, stacklevel=2
- )
- filter_backends = [self.filter_backend]
-
- return filter_backends
+ return list(self.filter_backends)
# The following methods provide default implementations
# that you may want to override for more complex cases.
- def get_paginate_by(self, queryset=None):
+ def get_paginate_by(self):
"""
Return the size of pages to use with pagination.
@@ -222,11 +164,6 @@ class GenericAPIView(views.APIView):
Otherwise defaults to using `self.paginate_by`.
"""
- if queryset is not None:
- warnings.warn('The `queryset` parameter to `get_paginate_by()` '
- 'is deprecated.',
- DeprecationWarning, stacklevel=2)
-
if self.paginate_by_param:
try:
return strict_positive_int(
@@ -248,26 +185,13 @@ class GenericAPIView(views.APIView):
(Eg. admins get full serialization, others get basic serialization)
"""
- serializer_class = self.serializer_class
- if serializer_class is not None:
- return serializer_class
-
- warnings.warn(
- 'The `.model` attribute on view classes is now deprecated in favor '
- 'of the more explicit `serializer_class` and `queryset` attributes.',
- DeprecationWarning, stacklevel=2
+ assert self.serializer_class is not None, (
+ "'%s' should either include a `serializer_class` attribute, "
+ "or override the `get_serializer_class()` method."
+ % self.__class__.__name__
)
- assert self.model is not None, \
- "'%s' should either include a 'serializer_class' attribute, " \
- "or use the 'model' attribute as a shortcut for " \
- "automatically generating a serializer class." \
- % self.__class__.__name__
-
- class DefaultSerializer(self.model_serializer_class):
- class Meta:
- model = self.model
- return DefaultSerializer
+ return self.serializer_class
def get_queryset(self):
"""
@@ -284,21 +208,15 @@ class GenericAPIView(views.APIView):
(Eg. return a list of items that is specific to the user)
"""
- if self.queryset is not None:
- return self.queryset._clone()
+ assert self.queryset is not None, (
+ "'%s' should either include a `queryset` attribute, "
+ "or override the `get_queryset()` method."
+ % self.__class__.__name__
+ )
- if self.model is not None:
- warnings.warn(
- 'The `.model` attribute on view classes is now deprecated in favor '
- 'of the more explicit `serializer_class` and `queryset` attributes.',
- DeprecationWarning, stacklevel=2
- )
- return self.model._default_manager.all()
+ return self.queryset._clone()
- error_format = "'%s' must define 'queryset' or 'model'"
- raise ImproperlyConfigured(error_format % self.__class__.__name__)
-
- def get_object(self, queryset=None):
+ def get_object(self):
"""
Returns the object the view is displaying.
@@ -306,43 +224,19 @@ class GenericAPIView(views.APIView):
queryset lookups. Eg if objects are referenced using multiple
keyword arguments in the url conf.
"""
- # Determine the base queryset to use.
- if queryset is None:
- queryset = self.filter_queryset(self.get_queryset())
- else:
- pass # Deprecation warning
+ queryset = self.filter_queryset(self.get_queryset())
# Perform the lookup filtering.
- # Note that `pk` and `slug` are deprecated styles of lookup filtering.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup = self.kwargs.get(lookup_url_kwarg, None)
- pk = self.kwargs.get(self.pk_url_kwarg, None)
- slug = self.kwargs.get(self.slug_url_kwarg, None)
- if lookup is not None:
- filter_kwargs = {self.lookup_field: lookup}
- elif pk is not None and self.lookup_field == 'pk':
- warnings.warn(
- 'The `pk_url_kwarg` attribute is deprecated. '
- 'Use the `lookup_field` attribute instead',
- DeprecationWarning
- )
- filter_kwargs = {'pk': pk}
- elif slug is not None and self.lookup_field == 'pk':
- warnings.warn(
- 'The `slug_url_kwarg` attribute is deprecated. '
- 'Use the `lookup_field` attribute instead',
- DeprecationWarning
- )
- filter_kwargs = {self.slug_field: slug}
- else:
- raise ImproperlyConfigured(
- 'Expected view %s to be called with a URL keyword argument '
- 'named "%s". Fix your URL conf, or set the `.lookup_field` '
- 'attribute on the view correctly.' %
- (self.__class__.__name__, self.lookup_field)
- )
+ assert lookup_url_kwarg in self.kwargs, (
+ 'Expected view %s to be called with a URL keyword argument '
+ 'named "%s". Fix your URL conf, or set the `.lookup_field` '
+ 'attribute on the view correctly.' %
+ (self.__class__.__name__, lookup_url_kwarg)
+ )
+ filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
obj = get_object_or_404(queryset, **filter_kwargs)
# May raise a permission denied
@@ -540,25 +434,3 @@ class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
-
-
-# Deprecated classes
-
-class MultipleObjectAPIView(GenericAPIView):
- def __init__(self, *args, **kwargs):
- warnings.warn(
- 'Subclassing `MultipleObjectAPIView` is deprecated. '
- 'You should simply subclass `GenericAPIView` instead.',
- DeprecationWarning, stacklevel=2
- )
- super(MultipleObjectAPIView, self).__init__(*args, **kwargs)
-
-
-class SingleObjectAPIView(GenericAPIView):
- def __init__(self, *args, **kwargs):
- warnings.warn(
- 'Subclassing `SingleObjectAPIView` is deprecated. '
- 'You should simply subclass `GenericAPIView` instead.',
- DeprecationWarning, stacklevel=2
- )
- super(SingleObjectAPIView, self).__init__(*args, **kwargs)
diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py
index 2cc87eef1..ee01cabc7 100644
--- a/rest_framework/mixins.py
+++ b/rest_framework/mixins.py
@@ -12,10 +12,9 @@ from rest_framework import status
from rest_framework.response import Response
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
-import warnings
-def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None):
+def _get_validation_exclusions(obj, lookup_field=None):
"""
Given a model instance, and an optional pk and slug field,
return the full list of all other field names on that model.
@@ -23,23 +22,13 @@ def _get_validation_exclusions(obj, pk=None, slug_field=None, lookup_field=None)
For use when performing full_clean on a model instance,
so we only clean the required fields.
"""
- include = []
-
- if pk:
- # Deprecated
+ if lookup_field == 'pk':
pk_field = obj._meta.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
- include.append(pk_field.name)
+ lookup_field = pk_field.name
- if slug_field:
- # Deprecated
- include.append(slug_field)
-
- if lookup_field and lookup_field != 'pk':
- include.append(lookup_field)
-
- return [field.name for field in obj._meta.fields if field.name not in include]
+ return [field.name for field in obj._meta.fields if field.name != lookup_field]
class CreateModelMixin(object):
@@ -47,12 +36,10 @@ class CreateModelMixin(object):
Create a model instance.
"""
def create(self, request, *args, **kwargs):
- serializer = self.get_serializer(data=request.DATA, files=request.FILES)
+ serializer = self.get_serializer(data=request.DATA)
if serializer.is_valid():
- self.pre_save(serializer.object)
- self.object = serializer.save(force_insert=True)
- self.post_save(self.object, created=True)
+ self.object = serializer.save()
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED,
headers=headers)
@@ -70,24 +57,9 @@ class ListModelMixin(object):
"""
List a queryset.
"""
- empty_error = "Empty list and '%(class_name)s.allow_empty' is False."
-
def list(self, request, *args, **kwargs):
self.object_list = self.filter_queryset(self.get_queryset())
- # Default is to allow empty querysets. This can be altered by setting
- # `.allow_empty = False`, to raise 404 errors on empty querysets.
- if not self.allow_empty and not self.object_list:
- warnings.warn(
- 'The `allow_empty` parameter is deprecated. '
- 'To use `allow_empty=False` style behavior, You should override '
- '`get_queryset()` and explicitly raise a 404 on empty querysets.',
- DeprecationWarning
- )
- class_name = self.__class__.__name__
- error_msg = self.empty_error % {'class_name': class_name}
- raise Http404(error_msg)
-
# Switch between paginated or standard style responses
page = self.paginate_queryset(self.object_list)
if page is not None:
@@ -116,26 +88,20 @@ class UpdateModelMixin(object):
partial = kwargs.pop('partial', False)
self.object = self.get_object_or_none()
- serializer = self.get_serializer(self.object, data=request.DATA,
- files=request.FILES, partial=partial)
+ serializer = self.get_serializer(self.object, data=request.DATA, partial=partial)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
- try:
- self.pre_save(serializer.object)
- except ValidationError as err:
- # full_clean on model instance may be called in pre_save,
- # so we have to handle eventual errors.
- return Response(err.message_dict, status=status.HTTP_400_BAD_REQUEST)
+ lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
+ lookup_value = self.kwargs[lookup_url_kwarg]
+ extras = {self.lookup_field: lookup_value}
if self.object is None:
- self.object = serializer.save(force_insert=True)
- self.post_save(self.object, created=True)
+ self.object = serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED)
- self.object = serializer.save(force_update=True)
- self.post_save(self.object, created=False)
+ self.object = serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_200_OK)
def partial_update(self, request, *args, **kwargs):
@@ -161,26 +127,15 @@ class UpdateModelMixin(object):
"""
Set any attributes on the object that are implicit in the request.
"""
- # pk and/or slug attributes are implicit in the URL.
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
- lookup = self.kwargs.get(lookup_url_kwarg, None)
- pk = self.kwargs.get(self.pk_url_kwarg, None)
- slug = self.kwargs.get(self.slug_url_kwarg, None)
- slug_field = slug and self.slug_field or None
+ lookup_value = self.kwargs[lookup_url_kwarg]
- if lookup:
- setattr(obj, self.lookup_field, lookup)
-
- if pk:
- setattr(obj, 'pk', pk)
-
- if slug:
- setattr(obj, slug_field, slug)
+ setattr(obj, self.lookup_field, lookup_value)
# Ensure we clean the attributes so that we don't eg return integer
# pk using a string representation, as provided by the url conf kwarg.
if hasattr(obj, 'full_clean'):
- exclude = _get_validation_exclusions(obj, pk, slug_field, self.lookup_field)
+ exclude = _get_validation_exclusions(obj, self.lookup_field)
obj.full_clean(exclude)
diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py
index d51ea929b..83ef97c5c 100644
--- a/rest_framework/pagination.py
+++ b/rest_framework/pagination.py
@@ -48,17 +48,17 @@ class DefaultObjectSerializer(serializers.Field):
super(DefaultObjectSerializer, self).__init__(source=source)
-class PaginationSerializerOptions(serializers.SerializerOptions):
- """
- An object that stores the options that may be provided to a
- pagination serializer by using the inner `Meta` class.
+# class PaginationSerializerOptions(serializers.SerializerOptions):
+# """
+# An object that stores the options that may be provided to a
+# pagination serializer by using the inner `Meta` class.
- Accessible on the instance as `serializer.opts`.
- """
- def __init__(self, meta):
- super(PaginationSerializerOptions, self).__init__(meta)
- self.object_serializer_class = getattr(meta, 'object_serializer_class',
- DefaultObjectSerializer)
+# Accessible on the instance as `serializer.opts`.
+# """
+# def __init__(self, meta):
+# super(PaginationSerializerOptions, self).__init__(meta)
+# self.object_serializer_class = getattr(meta, 'object_serializer_class',
+# DefaultObjectSerializer)
class BasePaginationSerializer(serializers.Serializer):
@@ -66,7 +66,7 @@ class BasePaginationSerializer(serializers.Serializer):
A base class for pagination serializers to inherit from,
to make implementing custom serializers more easy.
"""
- _options_class = PaginationSerializerOptions
+ # _options_class = PaginationSerializerOptions
results_field = 'results'
def __init__(self, *args, **kwargs):
diff --git a/rest_framework/relations.py b/rest_framework/relations.py
index 1acbdce26..e69de29bb 100644
--- a/rest_framework/relations.py
+++ b/rest_framework/relations.py
@@ -1,595 +0,0 @@
-"""
-Serializer fields that deal with relationships.
-
-These fields allow you to specify the style that should be used to represent
-model relationships, including hyperlinks, primary keys, or slugs.
-"""
-from __future__ import unicode_literals
-from django.core.exceptions import ObjectDoesNotExist, ValidationError
-from django.core.urlresolvers import resolve, get_script_prefix, NoReverseMatch
-from django import forms
-from django.db.models.fields import BLANK_CHOICE_DASH
-from django.forms import widgets
-from django.forms.models import ModelChoiceIterator
-from django.utils.translation import ugettext_lazy as _
-from rest_framework.fields import Field, WritableField, get_component, is_simple_callable
-from rest_framework.reverse import reverse
-from rest_framework.compat import urlparse
-from rest_framework.compat import smart_text
-import warnings
-
-
-# Relational fields
-
-# Not actually Writable, but subclasses may need to be.
-class RelatedField(WritableField):
- """
- Base class for related model fields.
-
- This represents a relationship using the unicode representation of the target.
- """
- widget = widgets.Select
- many_widget = widgets.SelectMultiple
- form_field_class = forms.ChoiceField
- many_form_field_class = forms.MultipleChoiceField
- null_values = (None, '', 'None')
-
- cache_choices = False
- empty_label = None
- read_only = True
- many = False
-
- def __init__(self, *args, **kwargs):
- queryset = kwargs.pop('queryset', None)
- self.many = kwargs.pop('many', self.many)
- if self.many:
- self.widget = self.many_widget
- self.form_field_class = self.many_form_field_class
-
- kwargs['read_only'] = kwargs.pop('read_only', self.read_only)
- super(RelatedField, self).__init__(*args, **kwargs)
-
- if not self.required:
- # Accessed in ModelChoiceIterator django/forms/models.py:1034
- # If set adds empty choice.
- self.empty_label = BLANK_CHOICE_DASH[0][1]
-
- self.queryset = queryset
-
- def initialize(self, parent, field_name):
- super(RelatedField, self).initialize(parent, field_name)
- if self.queryset is None and not self.read_only:
- manager = getattr(self.parent.opts.model, self.source or field_name)
- if hasattr(manager, 'related'): # Forward
- self.queryset = manager.related.model._default_manager.all()
- else: # Reverse
- self.queryset = manager.field.rel.to._default_manager.all()
-
- # We need this stuff to make form choices work...
-
- def prepare_value(self, obj):
- return self.to_native(obj)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_text(obj)
- ident = smart_text(self.to_native(obj))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- def _get_queryset(self):
- return self._queryset
-
- def _set_queryset(self, queryset):
- self._queryset = queryset
- self.widget.choices = self.choices
-
- queryset = property(_get_queryset, _set_queryset)
-
- def _get_choices(self):
- # If self._choices is set, then somebody must have manually set
- # the property self.choices. In this case, just return self._choices.
- if hasattr(self, '_choices'):
- return self._choices
-
- # Otherwise, execute the QuerySet in self.queryset to determine the
- # choices dynamically. Return a fresh ModelChoiceIterator that has not been
- # consumed. Note that we're instantiating a new ModelChoiceIterator *each*
- # time _get_choices() is called (and, thus, each time self.choices is
- # accessed) so that we can ensure the QuerySet has not been consumed. This
- # construct might look complicated but it allows for lazy evaluation of
- # the queryset.
- return ModelChoiceIterator(self)
-
- def _set_choices(self, value):
- # Setting choices also sets the choices on the widget.
- # choices can be any iterable, but we call list() on it because
- # it will be consumed more than once.
- self._choices = self.widget.choices = list(value)
-
- choices = property(_get_choices, _set_choices)
-
- # Default value handling
-
- def get_default_value(self):
- default = super(RelatedField, self).get_default_value()
- if self.many and default is None:
- return []
- return default
-
- # Regular serializer stuff...
-
- def field_to_native(self, obj, field_name):
- try:
- if self.source == '*':
- return self.to_native(obj)
-
- source = self.source or field_name
- value = obj
-
- for component in source.split('.'):
- if value is None:
- break
- value = get_component(value, component)
- except ObjectDoesNotExist:
- return None
-
- if value is None:
- return None
-
- if self.many:
- if is_simple_callable(getattr(value, 'all', None)):
- return [self.to_native(item) for item in value.all()]
- else:
- # Also support non-queryset iterables.
- # This allows us to also support plain lists of related items.
- return [self.to_native(item) for item in value]
- return self.to_native(value)
-
- def field_from_native(self, data, files, field_name, into):
- if self.read_only:
- return
-
- try:
- if self.many:
- try:
- # Form data
- value = data.getlist(field_name)
- if value == [''] or value == []:
- raise KeyError
- except AttributeError:
- # Non-form data
- value = data[field_name]
- else:
- value = data[field_name]
- except KeyError:
- if self.partial:
- return
- value = self.get_default_value()
-
- if value in self.null_values:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- into[(self.source or field_name)] = None
- elif self.many:
- into[(self.source or field_name)] = [self.from_native(item) for item in value]
- else:
- into[(self.source or field_name)] = self.from_native(value)
-
-
-# PrimaryKey relationships
-
-class PrimaryKeyRelatedField(RelatedField):
- """
- Represents a relationship as a pk value.
- """
- read_only = False
-
- default_error_messages = {
- 'does_not_exist': _("Invalid pk '%s' - object does not exist."),
- 'incorrect_type': _('Incorrect type. Expected pk value, received %s.'),
- }
-
- # TODO: Remove these field hacks...
- def prepare_value(self, obj):
- return self.to_native(obj.pk)
-
- def label_from_instance(self, obj):
- """
- Return a readable representation for use with eg. select widgets.
- """
- desc = smart_text(obj)
- ident = smart_text(self.to_native(obj.pk))
- if desc == ident:
- return desc
- return "%s - %s" % (desc, ident)
-
- # TODO: Possibly change this to just take `obj`, through prob less performant
- def to_native(self, pk):
- return pk
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- return self.queryset.get(pk=data)
- except ObjectDoesNotExist:
- msg = self.error_messages['does_not_exist'] % smart_text(data)
- raise ValidationError(msg)
- except (TypeError, ValueError):
- received = type(data).__name__
- msg = self.error_messages['incorrect_type'] % received
- raise ValidationError(msg)
-
- def field_to_native(self, obj, field_name):
- if self.many:
- # To-many relationship
-
- queryset = None
- if not self.source:
- # Prefer obj.serializable_value for performance reasons
- try:
- queryset = obj.serializable_value(field_name)
- except AttributeError:
- pass
- if queryset is None:
- # RelatedManager (reverse relationship)
- source = self.source or field_name
- queryset = obj
- for component in source.split('.'):
- if queryset is None:
- return []
- queryset = get_component(queryset, component)
-
- # Forward relationship
- if is_simple_callable(getattr(queryset, 'all', None)):
- return [self.to_native(item.pk) for item in queryset.all()]
- else:
- # Also support non-queryset iterables.
- # This allows us to also support plain lists of related items.
- return [self.to_native(item.pk) for item in queryset]
-
- # To-one relationship
- try:
- # Prefer obj.serializable_value for performance reasons
- pk = obj.serializable_value(self.source or field_name)
- except AttributeError:
- # RelatedObject (reverse relationship)
- try:
- pk = getattr(obj, self.source or field_name).pk
- except (ObjectDoesNotExist, AttributeError):
- return None
-
- # Forward relationship
- return self.to_native(pk)
-
-
-# Slug relationships
-
-class SlugRelatedField(RelatedField):
- """
- Represents a relationship using a unique field on the target.
- """
- read_only = False
-
- default_error_messages = {
- 'does_not_exist': _("Object with %s=%s does not exist."),
- 'invalid': _('Invalid value.'),
- }
-
- def __init__(self, *args, **kwargs):
- self.slug_field = kwargs.pop('slug_field', None)
- assert self.slug_field, 'slug_field is required'
- super(SlugRelatedField, self).__init__(*args, **kwargs)
-
- def to_native(self, obj):
- return getattr(obj, self.slug_field)
-
- def from_native(self, data):
- if self.queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- return self.queryset.get(**{self.slug_field: data})
- except ObjectDoesNotExist:
- raise ValidationError(self.error_messages['does_not_exist'] %
- (self.slug_field, smart_text(data)))
- except (TypeError, ValueError):
- msg = self.error_messages['invalid']
- raise ValidationError(msg)
-
-
-# Hyperlinked relationships
-
-class HyperlinkedRelatedField(RelatedField):
- """
- Represents a relationship using hyperlinking.
- """
- read_only = False
- lookup_field = 'pk'
-
- default_error_messages = {
- 'no_match': _('Invalid hyperlink - No URL match'),
- 'incorrect_match': _('Invalid hyperlink - Incorrect URL match'),
- 'configuration_error': _('Invalid hyperlink due to configuration error'),
- 'does_not_exist': _("Invalid hyperlink - object does not exist."),
- 'incorrect_type': _('Incorrect type. Expected url string, received %s.'),
- }
-
- # These are all deprecated
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except KeyError:
- raise ValueError("Hyperlinked field requires 'view_name' kwarg")
-
- self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
- self.format = kwargs.pop('format', None)
-
- # These are deprecated
- if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_field' in kwargs:
- msg = 'slug_field is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
- self.slug_field = kwargs.pop('slug_field', self.slug_field)
- default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
-
- super(HyperlinkedRelatedField, self).__init__(*args, **kwargs)
-
- def get_url(self, obj, view_name, request, format):
- """
- Given an object, return the URL that hyperlinks to the object.
-
- May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
- attributes are not configured to correctly match the URL conf.
- """
- lookup_field = getattr(obj, self.lookup_field)
- kwargs = {self.lookup_field: lookup_field}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- if self.pk_url_kwarg != 'pk':
- # Only try pk if it has been explicitly set.
- # Otherwise, the default `lookup_field = 'pk'` has us covered.
- pk = obj.pk
- kwargs = {self.pk_url_kwarg: pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- slug = getattr(obj, self.slug_field, None)
- if slug is not None:
- # Only try slug if it corresponds to an attribute on the object.
- kwargs = {self.slug_url_kwarg: slug}
- try:
- ret = reverse(view_name, kwargs=kwargs, request=request, format=format)
- if self.slug_field == 'slug' and self.slug_url_kwarg == 'slug':
- # If the lookup succeeds using the default slug params,
- # then `slug_field` is being used implicitly, and we
- # we need to warn about the pending deprecation.
- msg = 'Implicit slug field hyperlinked fields are deprecated.' \
- 'You should set `lookup_field=slug` on the HyperlinkedRelatedField.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- return ret
- except NoReverseMatch:
- pass
-
- raise NoReverseMatch()
-
- def get_object(self, queryset, view_name, view_args, view_kwargs):
- """
- Return the object corresponding to a matched URL.
-
- Takes the matched URL conf arguments, and the queryset, and should
- return an object instance, or raise an `ObjectDoesNotExist` exception.
- """
- lookup = view_kwargs.get(self.lookup_field, None)
- pk = view_kwargs.get(self.pk_url_kwarg, None)
- slug = view_kwargs.get(self.slug_url_kwarg, None)
-
- if lookup is not None:
- filter_kwargs = {self.lookup_field: lookup}
- elif pk is not None:
- filter_kwargs = {'pk': pk}
- elif slug is not None:
- filter_kwargs = {self.slug_field: slug}
- else:
- raise ObjectDoesNotExist()
-
- return queryset.get(**filter_kwargs)
-
- def to_native(self, obj):
- view_name = self.view_name
- request = self.context.get('request', None)
- format = self.format or self.context.get('format', None)
-
- assert request is not None, (
- "`HyperlinkedRelatedField` requires the request in the serializer "
- "context. Add `context={'request': request}` when instantiating "
- "the serializer."
- )
-
- # If the object has not yet been saved then we cannot hyperlink to it.
- if getattr(obj, 'pk', None) is None:
- return
-
- # Return the hyperlink, or error if incorrectly configured.
- try:
- return self.get_url(obj, view_name, request, format)
- except NoReverseMatch:
- msg = (
- 'Could not resolve URL for hyperlinked relationship using '
- 'view name "%s". You may have failed to include the related '
- 'model in your API, or incorrectly configured the '
- '`lookup_field` attribute on this field.'
- )
- raise Exception(msg % view_name)
-
- def from_native(self, value):
- # Convert URL -> model instance pk
- # TODO: Use values_list
- queryset = self.queryset
- if queryset is None:
- raise Exception('Writable related fields must include a `queryset` argument')
-
- try:
- http_prefix = value.startswith(('http:', 'https:'))
- except AttributeError:
- msg = self.error_messages['incorrect_type']
- raise ValidationError(msg % type(value).__name__)
-
- if http_prefix:
- # If needed convert absolute URLs to relative path
- value = urlparse.urlparse(value).path
- prefix = get_script_prefix()
- if value.startswith(prefix):
- value = '/' + value[len(prefix):]
-
- try:
- match = resolve(value)
- except Exception:
- raise ValidationError(self.error_messages['no_match'])
-
- if match.view_name != self.view_name:
- raise ValidationError(self.error_messages['incorrect_match'])
-
- try:
- return self.get_object(queryset, match.view_name,
- match.args, match.kwargs)
- except (ObjectDoesNotExist, TypeError, ValueError):
- raise ValidationError(self.error_messages['does_not_exist'])
-
-
-class HyperlinkedIdentityField(Field):
- """
- Represents the instance, or a property on the instance, using hyperlinking.
- """
- lookup_field = 'pk'
- read_only = True
-
- # These are all deprecated
- pk_url_kwarg = 'pk'
- slug_field = 'slug'
- slug_url_kwarg = None # Defaults to same as `slug_field` unless overridden
-
- def __init__(self, *args, **kwargs):
- try:
- self.view_name = kwargs.pop('view_name')
- except KeyError:
- msg = "HyperlinkedIdentityField requires 'view_name' argument"
- raise ValueError(msg)
-
- self.format = kwargs.pop('format', None)
- lookup_field = kwargs.pop('lookup_field', None)
- self.lookup_field = lookup_field or self.lookup_field
-
- # These are deprecated
- if 'pk_url_kwarg' in kwargs:
- msg = 'pk_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_url_kwarg' in kwargs:
- msg = 'slug_url_kwarg is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
- if 'slug_field' in kwargs:
- msg = 'slug_field is deprecated. Use lookup_field instead.'
- warnings.warn(msg, DeprecationWarning, stacklevel=2)
-
- self.slug_field = kwargs.pop('slug_field', self.slug_field)
- default_slug_kwarg = self.slug_url_kwarg or self.slug_field
- self.pk_url_kwarg = kwargs.pop('pk_url_kwarg', self.pk_url_kwarg)
- self.slug_url_kwarg = kwargs.pop('slug_url_kwarg', default_slug_kwarg)
-
- super(HyperlinkedIdentityField, self).__init__(*args, **kwargs)
-
- def field_to_native(self, obj, field_name):
- request = self.context.get('request', None)
- format = self.context.get('format', None)
- view_name = self.view_name
-
- assert request is not None, (
- "`HyperlinkedIdentityField` requires the request in the serializer"
- " context. Add `context={'request': request}` when instantiating "
- "the serializer."
- )
-
- # By default use whatever format is given for the current context
- # unless the target is a different type to the source.
- #
- # Eg. Consider a HyperlinkedIdentityField pointing from a json
- # representation to an html property of that representation...
- #
- # '/snippets/1/' should link to '/snippets/1/highlight/'
- # ...but...
- # '/snippets/1/.json' should link to '/snippets/1/highlight/.html'
- if format and self.format and self.format != format:
- format = self.format
-
- # Return the hyperlink, or error if incorrectly configured.
- try:
- return self.get_url(obj, view_name, request, format)
- except NoReverseMatch:
- msg = (
- 'Could not resolve URL for hyperlinked relationship using '
- 'view name "%s". You may have failed to include the related '
- 'model in your API, or incorrectly configured the '
- '`lookup_field` attribute on this field.'
- )
- raise Exception(msg % view_name)
-
- def get_url(self, obj, view_name, request, format):
- """
- Given an object, return the URL that hyperlinks to the object.
-
- May raise a `NoReverseMatch` if the `view_name` and `lookup_field`
- attributes are not configured to correctly match the URL conf.
- """
- lookup_field = getattr(obj, self.lookup_field, None)
- kwargs = {self.lookup_field: lookup_field}
-
- # Handle unsaved object case
- if lookup_field is None:
- return None
-
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- if self.pk_url_kwarg != 'pk':
- # Only try pk lookup if it has been explicitly set.
- # Otherwise, the default `lookup_field = 'pk'` has us covered.
- kwargs = {self.pk_url_kwarg: obj.pk}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- slug = getattr(obj, self.slug_field, None)
- if slug:
- # Only use slug lookup if a slug field exists on the model
- kwargs = {self.slug_url_kwarg: slug}
- try:
- return reverse(view_name, kwargs=kwargs, request=request, format=format)
- except NoReverseMatch:
- pass
-
- raise NoReverseMatch()
diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py
index 748ebac94..e8935b012 100644
--- a/rest_framework/renderers.py
+++ b/rest_framework/renderers.py
@@ -458,7 +458,7 @@ class BrowsableAPIRenderer(BaseRenderer):
):
return
- serializer = view.get_serializer(instance=obj, data=data, files=files)
+ serializer = view.get_serializer(instance=obj, data=data)
serializer.is_valid()
data = serializer.data
@@ -579,10 +579,10 @@ class BrowsableAPIRenderer(BaseRenderer):
'available_formats': [renderer_cls.format for renderer_cls in view.renderer_classes],
'response_headers': response_headers,
- 'put_form': self.get_rendered_html_form(view, 'PUT', request),
- 'post_form': self.get_rendered_html_form(view, 'POST', request),
- 'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
- 'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
+ #'put_form': self.get_rendered_html_form(view, 'PUT', request),
+ #'post_form': self.get_rendered_html_form(view, 'POST', request),
+ #'delete_form': self.get_rendered_html_form(view, 'DELETE', request),
+ #'options_form': self.get_rendered_html_form(view, 'OPTIONS', request),
'raw_data_put_form': raw_data_put_form,
'raw_data_post_form': raw_data_post_form,
diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py
index be8ad3f24..d121812d6 100644
--- a/rest_framework/serializers.py
+++ b/rest_framework/serializers.py
@@ -10,21 +10,14 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
-from __future__ import unicode_literals
-import copy
-import datetime
-import inspect
-import types
-from decimal import Decimal
-from django.contrib.contenttypes.generic import GenericForeignKey
-from django.core.paginator import Page
from django.db import models
-from django.forms import widgets
from django.utils import six
-from django.utils.datastructures import SortedDict
-from django.core.exceptions import ObjectDoesNotExist
+from collections import namedtuple, OrderedDict
+from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError
from rest_framework.settings import api_settings
-
+from rest_framework.utils import html
+import copy
+import inspect
# Note: We do the following so that users of the framework can use this style:
#
@@ -37,6 +30,253 @@ from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA
+FieldResult = namedtuple('FieldResult', ['field', 'value', 'error'])
+
+
+class BaseSerializer(Field):
+ def __init__(self, instance=None, data=None, **kwargs):
+ super(BaseSerializer, self).__init__(**kwargs)
+ self.instance = instance
+ self._initial_data = data
+
+ def to_native(self, data):
+ raise NotImplementedError()
+
+ def to_primative(self, instance):
+ raise NotImplementedError()
+
+ def update(self, instance):
+ raise NotImplementedError()
+
+ def create(self):
+ raise NotImplementedError()
+
+ def save(self, extras=None):
+ if extras is not None:
+ self._validated_data.update(extras)
+
+ if self.instance is not None:
+ self.update(self.instance)
+ else:
+ self.instance = self.create()
+
+ return self.instance
+
+ def is_valid(self):
+ try:
+ self._validated_data = self.to_native(self._initial_data)
+ except ValidationError as exc:
+ self._validated_data = {}
+ self._errors = exc.args[0]
+ return False
+ self._errors = {}
+ return True
+
+ @property
+ def data(self):
+ if not hasattr(self, '_data'):
+ if self.instance is not None:
+ self._data = self.to_primative(self.instance)
+ elif self._initial_data is not None:
+ self._data = {
+ field_name: field.get_value(self._initial_data)
+ for field_name, field in self.fields.items()
+ }
+ else:
+ self._data = self.get_initial()
+ return self._data
+
+ @property
+ def errors(self):
+ if not hasattr(self, '_errors'):
+ msg = 'You must call `.is_valid()` before accessing `.errors`.'
+ raise AssertionError(msg)
+ return self._errors
+
+ @property
+ def validated_data(self):
+ if not hasattr(self, '_validated_data'):
+ msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
+ raise AssertionError(msg)
+ return self._validated_data
+
+
+class SerializerMetaclass(type):
+ """
+ This metaclass sets a dictionary named `base_fields` on the class.
+
+ Any fields included as attributes on either the class or it's superclasses
+ will be include in the `base_fields` dictionary.
+ """
+
+ @classmethod
+ def _get_fields(cls, bases, attrs):
+ fields = [(field_name, attrs.pop(field_name))
+ for field_name, obj in list(attrs.items())
+ if isinstance(obj, Field)]
+ fields.sort(key=lambda x: x[1]._creation_counter)
+
+ # If this class is subclassing another Serializer, add that Serializer's
+ # fields. Note that we loop over the bases in *reverse*. This is necessary
+ # in order to maintain the correct order of fields.
+ for base in bases[::-1]:
+ if hasattr(base, 'base_fields'):
+ fields = list(base.base_fields.items()) + fields
+
+ return OrderedDict(fields)
+
+ def __new__(cls, name, bases, attrs):
+ attrs['base_fields'] = cls._get_fields(bases, attrs)
+ return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
+
+
+@six.add_metaclass(SerializerMetaclass)
+class Serializer(BaseSerializer):
+
+ def __new__(cls, *args, **kwargs):
+ many = kwargs.pop('many', False)
+ if many:
+ class DynamicListSerializer(ListSerializer):
+ child = cls()
+ return DynamicListSerializer(*args, **kwargs)
+ return super(Serializer, cls).__new__(cls)
+
+ def __init__(self, *args, **kwargs):
+ kwargs.pop('context', None)
+ kwargs.pop('partial', None)
+ kwargs.pop('many', False)
+
+ super(Serializer, self).__init__(*args, **kwargs)
+
+ # Every new serializer is created with a clone of the field instances.
+ # This allows users to dynamically modify the fields on a serializer
+ # instance without affecting every other serializer class.
+ self.fields = self.get_fields()
+
+ # Setup all the child fields, to provide them with the current context.
+ for field_name, field in self.fields.items():
+ field.bind(field_name, self, self)
+
+ def get_fields(self):
+ return copy.deepcopy(self.base_fields)
+
+ def bind(self, field_name, parent, root):
+ # If the serializer is used as a field then when it becomes bound
+ # it also needs to bind all its child fields.
+ super(Serializer, self).bind(field_name, parent, root)
+ for field_name, field in self.fields.items():
+ field.bind(field_name, self, root)
+
+ def get_initial(self):
+ return {
+ field.field_name: field.get_initial()
+ for field in self.fields.values()
+ }
+
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # nested HTML forms.
+ if html.is_html_input(dictionary):
+ return html.parse_html_dict(dictionary, prefix=self.field_name)
+ return dictionary.get(self.field_name, empty)
+
+ def to_native(self, data):
+ """
+ Dict of native values <- Dict of primitive datatypes.
+ """
+ ret = {}
+ errors = {}
+ fields = [field for field in self.fields.values() if not field.read_only]
+
+ for field in fields:
+ primitive_value = field.get_value(data)
+ try:
+ validated_value = field.validate(primitive_value)
+ except ValidationError as exc:
+ errors[field.field_name] = str(exc)
+ except SkipField:
+ pass
+ else:
+ set_value(ret, field.source_attrs, validated_value)
+
+ if errors:
+ raise ValidationError(errors)
+
+ return ret
+
+ def to_primative(self, instance):
+ """
+ Object instance -> Dict of primitive datatypes.
+ """
+ ret = OrderedDict()
+ fields = [field for field in self.fields.values() if not field.write_only]
+
+ for field in fields:
+ native_value = field.get_attribute(instance)
+ ret[field.field_name] = field.to_primative(native_value)
+
+ return ret
+
+ def __iter__(self):
+ errors = self.errors if hasattr(self, '_errors') else {}
+ for field in self.fields.values():
+ value = self.data.get(field.field_name) if self.data else None
+ error = errors.get(field.field_name)
+ yield FieldResult(field, value, error)
+
+
+class ListSerializer(BaseSerializer):
+ child = None
+ initial = []
+
+ def __init__(self, *args, **kwargs):
+ self.child = kwargs.pop('child', copy.deepcopy(self.child))
+ assert self.child is not None, '`child` is a required argument.'
+
+ kwargs.pop('context', None)
+ kwargs.pop('partial', None)
+
+ super(ListSerializer, self).__init__(*args, **kwargs)
+ self.child.bind('', self, self)
+
+ def bind(self, field_name, parent, root):
+ # If the list is used as a field then it needs to provide
+ # the current context to the child serializer.
+ super(ListSerializer, self).bind(field_name, parent, root)
+ self.child.bind(field_name, self, root)
+
+ def get_value(self, dictionary):
+ # We override the default field access in order to support
+ # lists in HTML forms.
+ if is_html_input(dictionary):
+ return html.parse_html_list(dictionary, prefix=self.field_name)
+ return dictionary.get(self.field_name, empty)
+
+ def to_native(self, data):
+ """
+ List of dicts of native values <- List of dicts of primitive datatypes.
+ """
+ if html.is_html_input(data):
+ data = html.parse_html_list(data)
+
+ return [self.child.validate(item) for item in data]
+
+ def to_primative(self, data):
+ """
+ List of object instances -> List of dicts of primitive datatypes.
+ """
+ return [self.child.to_primative(item) for item in data]
+
+ def create(self, attrs_list):
+ return [self.child.create(attrs) for attrs in attrs_list]
+
+ def save(self):
+ if self.instance is not None:
+ self.update(self.instance, self.validated_data)
+ self.instance = self.create(self.validated_data)
+ return self.instance
+
+
def _resolve_model(obj):
"""
Resolve supplied `obj` to a Django model class.
@@ -58,614 +298,71 @@ def _resolve_model(obj):
raise ValueError("{0} is not a Django model".format(obj))
-def pretty_name(name):
- """Converts 'first_name' to 'First name'"""
- if not name:
- return ''
- return name.replace('_', ' ').capitalize()
-
-
-class RelationsList(list):
- _deleted = []
-
-
-class NestedValidationError(ValidationError):
- """
- The default ValidationError behavior is to stringify each item in the list
- if the messages are a list of error messages.
-
- In the case of nested serializers, where the parent has many children,
- then the child's `serializer.errors` will be a list of dicts. In the case
- of a single child, the `serializer.errors` will be a dict.
-
- We need to override the default behavior to get properly nested error dicts.
- """
-
- def __init__(self, message):
- if isinstance(message, dict):
- self._messages = [message]
- else:
- self._messages = message
-
- @property
- def messages(self):
- return self._messages
-
-
-class DictWithMetadata(dict):
- """
- A dict-like object, that can have additional properties attached.
- """
- def __getstate__(self):
- """
- Used by pickle (e.g., caching).
- Overridden to remove the metadata from the dict, since it shouldn't be
- pickled and may in some instances be unpickleable.
- """
- return dict(self)
-
-
-class SortedDictWithMetadata(SortedDict):
- """
- A sorted dict-like object, that can have additional properties attached.
- """
- def __getstate__(self):
- """
- Used by pickle (e.g., caching).
- Overriden to remove the metadata from the dict, since it shouldn't be
- pickle and may in some instances be unpickleable.
- """
- return SortedDict(self).__dict__
-
-
-def _is_protected_type(obj):
- """
- True if the object is a native datatype that does not need to
- be serialized further.
- """
- return isinstance(obj, (
- types.NoneType,
- int, long,
- datetime.datetime, datetime.date, datetime.time,
- float, Decimal,
- basestring)
- )
-
-
-def _get_declared_fields(bases, attrs):
- """
- Create a list of serializer field instances from the passed in 'attrs',
- plus any fields on the base classes (in 'bases').
-
- Note that all fields from the base classes are used.
- """
- fields = [(field_name, attrs.pop(field_name))
- for field_name, obj in list(six.iteritems(attrs))
- if isinstance(obj, Field)]
- fields.sort(key=lambda x: x[1].creation_counter)
-
- # If this class is subclassing another Serializer, add that Serializer's
- # fields. Note that we loop over the bases in *reverse*. This is necessary
- # in order to maintain the correct order of fields.
- for base in bases[::-1]:
- if hasattr(base, 'base_fields'):
- fields = list(base.base_fields.items()) + fields
-
- return SortedDict(fields)
-
-
-class SerializerMetaclass(type):
- def __new__(cls, name, bases, attrs):
- attrs['base_fields'] = _get_declared_fields(bases, attrs)
- return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
-
-
-class SerializerOptions(object):
- """
- Meta class options for Serializer
- """
- def __init__(self, meta):
- self.depth = getattr(meta, 'depth', 0)
- self.fields = getattr(meta, 'fields', ())
- self.exclude = getattr(meta, 'exclude', ())
-
-
-class BaseSerializer(WritableField):
- """
- This is the Serializer implementation.
- We need to implement it as `BaseSerializer` due to metaclass magicks.
- """
- class Meta(object):
- pass
-
- _options_class = SerializerOptions
- _dict_class = SortedDictWithMetadata
-
- def __init__(self, instance=None, data=None, files=None,
- context=None, partial=False, many=False,
- allow_add_remove=False, **kwargs):
- super(BaseSerializer, self).__init__(**kwargs)
- self.opts = self._options_class(self.Meta)
- self.parent = None
- self.root = None
- self.partial = partial
- self.many = many
- self.allow_add_remove = allow_add_remove
-
- self.context = context or {}
-
- self.init_data = data
- self.init_files = files
- self.object = instance
- self.fields = self.get_fields()
-
- self._data = None
- self._files = None
- self._errors = None
-
- if many and instance is not None and not hasattr(instance, '__iter__'):
- raise ValueError('instance should be a queryset or other iterable with many=True')
-
- if allow_add_remove and not many:
- raise ValueError('allow_add_remove should only be used for bulk updates, but you have not set many=True')
-
- #####
- # Methods to determine which fields to use when (de)serializing objects.
-
- def get_default_fields(self):
- """
- Return the complete set of default fields for the object, as a dict.
- """
- return {}
-
- def get_fields(self):
- """
- Returns the complete set of fields for the object as a dict.
-
- This will be the set of any explicitly declared fields,
- plus the set of fields returned by get_default_fields().
- """
- ret = SortedDict()
-
- # Get the explicitly declared fields
- base_fields = copy.deepcopy(self.base_fields)
- for key, field in base_fields.items():
- ret[key] = field
-
- # Add in the default fields
- default_fields = self.get_default_fields()
- for key, val in default_fields.items():
- if key not in ret:
- ret[key] = val
-
- # If 'fields' is specified, use those fields, in that order.
- if self.opts.fields:
- assert isinstance(self.opts.fields, (list, tuple)), '`fields` must be a list or tuple'
- new = SortedDict()
- for key in self.opts.fields:
- new[key] = ret[key]
- ret = new
-
- # Remove anything in 'exclude'
- if self.opts.exclude:
- assert isinstance(self.opts.exclude, (list, tuple)), '`exclude` must be a list or tuple'
- for key in self.opts.exclude:
- ret.pop(key, None)
-
- for key, field in ret.items():
- field.initialize(parent=self, field_name=key)
-
- return ret
-
- #####
- # Methods to convert or revert from objects <--> primitive representations.
-
- def get_field_key(self, field_name):
- """
- Return the key that should be used for a given field.
- """
- return field_name
-
- def restore_fields(self, data, files):
- """
- Core of deserialization, together with `restore_object`.
- Converts a dictionary of data into a dictionary of deserialized fields.
- """
- reverted_data = {}
-
- if data is not None and not isinstance(data, dict):
- self._errors['non_field_errors'] = ['Invalid data']
- return None
-
- for field_name, field in self.fields.items():
- field.initialize(parent=self, field_name=field_name)
- try:
- field.field_from_native(data, files, field_name, reverted_data)
- except ValidationError as err:
- self._errors[field_name] = list(err.messages)
-
- return reverted_data
-
- def perform_validation(self, attrs):
- """
- Run `validate_()` and `validate()` methods on the serializer
- """
- for field_name, field in self.fields.items():
- if field_name in self._errors:
- continue
-
- source = field.source or field_name
- if self.partial and source not in attrs:
- continue
- try:
- validate_method = getattr(self, 'validate_%s' % field_name, None)
- if validate_method:
- attrs = validate_method(attrs, source)
- except ValidationError as err:
- self._errors[field_name] = self._errors.get(field_name, []) + list(err.messages)
-
- # If there are already errors, we don't run .validate() because
- # field-validation failed and thus `attrs` may not be complete.
- # which in turn can cause inconsistent validation errors.
- if not self._errors:
- try:
- attrs = self.validate(attrs)
- except ValidationError as err:
- if hasattr(err, 'message_dict'):
- for field_name, error_messages in err.message_dict.items():
- self._errors[field_name] = self._errors.get(field_name, []) + list(error_messages)
- elif hasattr(err, 'messages'):
- self._errors['non_field_errors'] = err.messages
-
- return attrs
-
- def validate(self, attrs):
- """
- Stub method, to be overridden in Serializer subclasses
- """
- return attrs
-
- def restore_object(self, attrs, instance=None):
- """
- Deserialize a dictionary of attributes into an object instance.
- You should override this method to control how deserialized objects
- are instantiated.
- """
- if instance is not None:
- instance.update(attrs)
- return instance
- return attrs
-
- def to_native(self, obj):
- """
- Serialize objects -> primitives.
- """
- ret = self._dict_class()
- ret.fields = self._dict_class()
-
- for field_name, field in self.fields.items():
- if field.read_only and obj is None:
- continue
- field.initialize(parent=self, field_name=field_name)
- key = self.get_field_key(field_name)
- value = field.field_to_native(obj, field_name)
- method = getattr(self, 'transform_%s' % field_name, None)
- if callable(method):
- value = method(obj, value)
- if not getattr(field, 'write_only', False):
- ret[key] = value
- ret.fields[key] = self.augment_field(field, field_name, key, value)
-
- return ret
-
- def from_native(self, data, files=None):
- """
- Deserialize primitives -> objects.
- """
- self._errors = {}
-
- if data is not None or files is not None:
- attrs = self.restore_fields(data, files)
- if attrs is not None:
- attrs = self.perform_validation(attrs)
- else:
- self._errors['non_field_errors'] = ['No input provided']
-
- if not self._errors:
- return self.restore_object(attrs, instance=getattr(self, 'object', None))
-
- def augment_field(self, field, field_name, key, value):
- # This horrible stuff is to manage serializers rendering to HTML
- field._errors = self._errors.get(key) if self._errors else None
- field._name = field_name
- field._value = self.init_data.get(key) if self._errors and self.init_data else value
- if not field.label:
- field.label = pretty_name(key)
- return field
-
- def field_to_native(self, obj, field_name):
- """
- Override default so that the serializer can be used as a nested field
- across relationships.
- """
- if self.write_only:
- return None
-
- if self.source == '*':
- return self.to_native(obj)
-
- # Get the raw field value
- try:
- source = self.source or field_name
- value = obj
-
- for component in source.split('.'):
- if value is None:
- break
- value = get_component(value, component)
- except ObjectDoesNotExist:
- return None
-
- if is_simple_callable(getattr(value, 'all', None)):
- return [self.to_native(item) for item in value.all()]
-
- if value is None:
- return None
-
- if self.many:
- return [self.to_native(item) for item in value]
- return self.to_native(value)
-
- def field_from_native(self, data, files, field_name, into):
- """
- Override default so that the serializer can be used as a writable
- nested field across relationships.
- """
- if self.read_only:
- return
-
- try:
- value = data[field_name]
- except KeyError:
- if self.default is not None and not self.partial:
- # Note: partial updates shouldn't set defaults
- value = copy.deepcopy(self.default)
- else:
- if self.required:
- raise ValidationError(self.error_messages['required'])
- return
-
- if self.source == '*':
- if value:
- reverted_data = self.restore_fields(value, {})
- if not self._errors:
- into.update(reverted_data)
- else:
- if value in (None, ''):
- into[(self.source or field_name)] = None
- else:
- # Set the serializer object if it exists
- obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None
-
- # If we have a model manager or similar object then we need
- # to iterate through each instance.
- if (
- self.many and
- not hasattr(obj, '__iter__') and
- is_simple_callable(getattr(obj, 'all', None))
- ):
- obj = obj.all()
-
- kwargs = {
- 'instance': obj,
- 'data': value,
- 'context': self.context,
- 'partial': self.partial,
- 'many': self.many,
- 'allow_add_remove': self.allow_add_remove
- }
- serializer = self.__class__(**kwargs)
-
- if serializer.is_valid():
- into[self.source or field_name] = serializer.object
- else:
- # Propagate errors up to our parent
- raise NestedValidationError(serializer.errors)
-
- def get_identity(self, data):
- """
- This hook is required for bulk update.
- It is used to determine the canonical identity of a given object.
-
- Note that the data has not been validated at this point, so we need
- to make sure that we catch any cases of incorrect datatypes being
- passed to this method.
- """
- try:
- return data.get('id', None)
- except AttributeError:
- return None
-
- @property
- def errors(self):
- """
- Run deserialization and return error data,
- setting self.object if no errors occurred.
- """
- if self._errors is None:
- data, files = self.init_data, self.init_files
-
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(data, '__iter__') and not isinstance(data, (Page, dict, six.text_type))
- if many:
- warnings.warn('Implicit list/queryset serialization is deprecated. '
- 'Use the `many=True` flag when instantiating the serializer.',
- DeprecationWarning, stacklevel=3)
-
- if many:
- ret = RelationsList()
- errors = []
- update = self.object is not None
-
- if update:
- # If this is a bulk update we need to map all the objects
- # to a canonical identity so we can determine which
- # individual object is being updated for each item in the
- # incoming data
- objects = self.object
- identities = [self.get_identity(self.to_native(obj)) for obj in objects]
- identity_to_objects = dict(zip(identities, objects))
-
- if hasattr(data, '__iter__') and not isinstance(data, (dict, six.text_type)):
- for item in data:
- if update:
- # Determine which object we're updating
- identity = self.get_identity(item)
- self.object = identity_to_objects.pop(identity, None)
- if self.object is None and not self.allow_add_remove:
- ret.append(None)
- errors.append({'non_field_errors': ['Cannot create a new item, only existing items may be updated.']})
- continue
-
- ret.append(self.from_native(item, None))
- errors.append(self._errors)
-
- if update and self.allow_add_remove:
- ret._deleted = identity_to_objects.values()
-
- self._errors = any(errors) and errors or []
- else:
- self._errors = {'non_field_errors': ['Expected a list of items.']}
- else:
- ret = self.from_native(data, files)
-
- if not self._errors:
- self.object = ret
-
- return self._errors
-
- def is_valid(self):
- return not self.errors
-
- @property
- def data(self):
- """
- Returns the serialized data on the serializer.
- """
- if self._data is None:
- obj = self.object
-
- if self.many is not None:
- many = self.many
- else:
- many = hasattr(obj, '__iter__') and not isinstance(obj, (Page, dict))
- if many:
- warnings.warn('Implicit list/queryset serialization is deprecated. '
- 'Use the `many=True` flag when instantiating the serializer.',
- DeprecationWarning, stacklevel=2)
-
- if many:
- self._data = [self.to_native(item) for item in obj]
- else:
- self._data = self.to_native(obj)
-
- return self._data
-
- def save_object(self, obj, **kwargs):
- obj.save(**kwargs)
-
- def delete_object(self, obj):
- obj.delete()
-
- def save(self, **kwargs):
- """
- Save the deserialized object and return it.
- """
- # Clear cached _data, which may be invalidated by `save()`
- self._data = None
-
- if isinstance(self.object, list):
- [self.save_object(item, **kwargs) for item in self.object]
-
- if self.object._deleted:
- [self.delete_object(item) for item in self.object._deleted]
- else:
- self.save_object(self.object, **kwargs)
-
- return self.object
-
- def metadata(self):
- """
- Return a dictionary of metadata about the fields on the serializer.
- Useful for things like responding to OPTIONS requests, or generating
- API schemas for auto-documentation.
- """
- return SortedDict(
- [
- (field_name, field.metadata())
- for field_name, field in six.iteritems(self.fields)
- ]
- )
-
-
-class Serializer(six.with_metaclass(SerializerMetaclass, BaseSerializer)):
- pass
-
-
-class ModelSerializerOptions(SerializerOptions):
+class ModelSerializerOptions(object):
"""
Meta class options for ModelSerializer
"""
def __init__(self, meta):
- super(ModelSerializerOptions, self).__init__(meta)
- self.model = getattr(meta, 'model', None)
- self.read_only_fields = getattr(meta, 'read_only_fields', ())
- self.write_only_fields = getattr(meta, 'write_only_fields', ())
+ self.model = getattr(meta, 'model')
+ self.fields = getattr(meta, 'fields', ())
+ self.depth = getattr(meta, 'depth', 0)
class ModelSerializer(Serializer):
- """
- A serializer that deals with model instances and querysets.
- """
- _options_class = ModelSerializerOptions
-
field_mapping = {
models.AutoField: IntegerField,
- models.FloatField: FloatField,
+ # models.FloatField: FloatField,
models.IntegerField: IntegerField,
models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
- models.DateTimeField: DateTimeField,
- models.DateField: DateField,
- models.TimeField: TimeField,
- models.DecimalField: DecimalField,
- models.EmailField: EmailField,
+ # models.DateTimeField: DateTimeField,
+ # models.DateField: DateField,
+ # models.TimeField: TimeField,
+ # models.DecimalField: DecimalField,
+ # models.EmailField: EmailField,
models.CharField: CharField,
- models.URLField: URLField,
- models.SlugField: SlugField,
+ # models.URLField: URLField,
+ # models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField,
- models.FileField: FileField,
- models.ImageField: ImageField,
+ # models.FileField: FileField,
+ # models.ImageField: ImageField,
}
+ _options_class = ModelSerializerOptions
+
+ def __init__(self, *args, **kwargs):
+ self.opts = self._options_class(self.Meta)
+ super(ModelSerializer, self).__init__(*args, **kwargs)
+
+ def get_fields(self):
+ # Get the explicitly declared fields.
+ fields = copy.deepcopy(self.base_fields)
+
+ # Add in the default fields.
+ for key, val in self.get_default_fields().items():
+ if key not in fields:
+ fields[key] = val
+
+ # If `fields` is set on the `Meta` class,
+ # then use only those fields, and in that order.
+ if self.opts.fields:
+ fields = OrderedDict([
+ (key, fields[key]) for key in self.opts.fields
+ ])
+
+ return fields
+
def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
-
cls = self.opts.model
- assert cls is not None, (
- "Serializer class '%s' is missing 'model' Meta option" %
- self.__class__.__name__
- )
opts = cls._meta.concrete_model._meta
- ret = SortedDict()
+ ret = OrderedDict()
nested = bool(self.opts.depth)
# Deal with adding the primary key field
@@ -694,29 +391,9 @@ class ModelSerializer(Serializer):
has_through_model = True
if model_field.rel and nested:
- if len(inspect.getargspec(self.get_nested_field).args) == 2:
- warnings.warn(
- 'The `get_nested_field(model_field)` call signature '
- 'is deprecated. '
- 'Use `get_nested_field(model_field, related_model, '
- 'to_many) instead',
- DeprecationWarning
- )
- field = self.get_nested_field(model_field)
- else:
- field = self.get_nested_field(model_field, related_model, to_many)
+ field = self.get_nested_field(model_field, related_model, to_many)
elif model_field.rel:
- if len(inspect.getargspec(self.get_nested_field).args) == 3:
- warnings.warn(
- 'The `get_related_field(model_field, to_many)` call '
- 'signature is deprecated. '
- 'Use `get_related_field(model_field, related_model, '
- 'to_many) instead',
- DeprecationWarning
- )
- field = self.get_related_field(model_field, to_many=to_many)
- else:
- field = self.get_related_field(model_field, related_model, to_many)
+ field = self.get_related_field(model_field, related_model, to_many)
else:
field = self.get_field(model_field)
@@ -763,38 +440,6 @@ class ModelSerializer(Serializer):
ret[accessor_name] = field
- # Ensure that 'read_only_fields' is an iterable
- assert isinstance(self.opts.read_only_fields, (list, tuple)), '`read_only_fields` must be a list or tuple'
-
- # Add the `read_only` flag to any fields that have been specified
- # in the `read_only_fields` option
- for field_name in self.opts.read_only_fields:
- assert field_name not in self.base_fields.keys(), (
- "field '%s' on serializer '%s' specified in "
- "`read_only_fields`, but also added "
- "as an explicit field. Remove it from `read_only_fields`." %
- (field_name, self.__class__.__name__))
- assert field_name in ret, (
- "Non-existant field '%s' specified in `read_only_fields` "
- "on serializer '%s'." %
- (field_name, self.__class__.__name__))
- ret[field_name].read_only = True
-
- # Ensure that 'write_only_fields' is an iterable
- assert isinstance(self.opts.write_only_fields, (list, tuple)), '`write_only_fields` must be a list or tuple'
-
- for field_name in self.opts.write_only_fields:
- assert field_name not in self.base_fields.keys(), (
- "field '%s' on serializer '%s' specified in "
- "`write_only_fields`, but also added "
- "as an explicit field. Remove it from `write_only_fields`." %
- (field_name, self.__class__.__name__))
- assert field_name in ret, (
- "Non-existant field '%s' specified in `write_only_fields` "
- "on serializer '%s'." %
- (field_name, self.__class__.__name__))
- ret[field_name].write_only = True
-
return ret
def get_pk_field(self, model_field):
@@ -825,28 +470,24 @@ class ModelSerializer(Serializer):
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- kwargs = {
- 'queryset': related_model._default_manager,
- 'many': to_many
- }
+ kwargs = {}
+ # 'queryset': related_model._default_manager,
+ # 'many': to_many
+ # }
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
+ # if model_field.help_text is not None:
+ # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
-
if not model_field.editable:
kwargs['read_only'] = True
-
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
-
- return PrimaryKeyRelatedField(**kwargs)
+ return IntegerField(**kwargs)
+ # TODO: return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
@@ -869,8 +510,8 @@ class ModelSerializer(Serializer):
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
+ # if model_field.help_text is not None:
+ # kwargs['help_text'] = model_field.help_text
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
@@ -880,7 +521,7 @@ class ModelSerializer(Serializer):
return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
- if issubclass(model_field.__class__, models.PositiveIntegerField) or\
+ if issubclass(model_field.__class__, models.PositiveIntegerField) or \
issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0
@@ -888,170 +529,27 @@ class ModelSerializer(Serializer):
issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True
- attribute_dict = {
- models.CharField: ['max_length'],
- models.CommaSeparatedIntegerField: ['max_length'],
- models.DecimalField: ['max_digits', 'decimal_places'],
- models.EmailField: ['max_length'],
- models.FileField: ['max_length'],
- models.ImageField: ['max_length'],
- models.SlugField: ['max_length'],
- models.URLField: ['max_length'],
- }
+ # attribute_dict = {
+ # models.CharField: ['max_length'],
+ # models.CommaSeparatedIntegerField: ['max_length'],
+ # models.DecimalField: ['max_digits', 'decimal_places'],
+ # models.EmailField: ['max_length'],
+ # models.FileField: ['max_length'],
+ # models.ImageField: ['max_length'],
+ # models.SlugField: ['max_length'],
+ # models.URLField: ['max_length'],
+ # }
- if model_field.__class__ in attribute_dict:
- attributes = attribute_dict[model_field.__class__]
- for attribute in attributes:
- kwargs.update({attribute: getattr(model_field, attribute)})
+ # if model_field.__class__ in attribute_dict:
+ # attributes = attribute_dict[model_field.__class__]
+ # for attribute in attributes:
+ # kwargs.update({attribute: getattr(model_field, attribute)})
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
- return ModelField(model_field=model_field, **kwargs)
-
- def get_validation_exclusions(self, instance=None):
- """
- Return a list of field names to exclude from model validation.
- """
- cls = self.opts.model
- opts = cls._meta.concrete_model._meta
- exclusions = [field.name for field in opts.fields + opts.many_to_many]
-
- for field_name, field in self.fields.items():
- field_name = field.source or field_name
- if (
- field_name in exclusions
- and not field.read_only
- and (field.required or hasattr(instance, field_name))
- and not isinstance(field, Serializer)
- ):
- exclusions.remove(field_name)
- return exclusions
-
- def full_clean(self, instance):
- """
- Perform Django's full_clean, and populate the `errors` dictionary
- if any validation errors occur.
-
- Note that we don't perform this inside the `.restore_object()` method,
- so that subclasses can override `.restore_object()`, and still get
- the full_clean validation checking.
- """
- try:
- instance.full_clean(exclude=self.get_validation_exclusions(instance))
- except ValidationError as err:
- self._errors = err.message_dict
- return None
- return instance
-
- def restore_object(self, attrs, instance=None):
- """
- Restore the model instance.
- """
- m2m_data = {}
- related_data = {}
- nested_forward_relations = {}
- meta = self.opts.model._meta
-
- # Reverse fk or one-to-one relations
- for (obj, model) in meta.get_all_related_objects_with_model():
- field_name = obj.get_accessor_name()
- if field_name in attrs:
- related_data[field_name] = attrs.pop(field_name)
-
- # Reverse m2m relations
- for (obj, model) in meta.get_all_related_m2m_objects_with_model():
- field_name = obj.get_accessor_name()
- if field_name in attrs:
- m2m_data[field_name] = attrs.pop(field_name)
-
- # Forward m2m relations
- for field in meta.many_to_many + meta.virtual_fields:
- if isinstance(field, GenericForeignKey):
- continue
- if field.name in attrs:
- m2m_data[field.name] = attrs.pop(field.name)
-
- # Nested forward relations - These need to be marked so we can save
- # them before saving the parent model instance.
- for field_name in attrs.keys():
- if isinstance(self.fields.get(field_name, None), Serializer):
- nested_forward_relations[field_name] = attrs[field_name]
-
- # Create an empty instance of the model
- if instance is None:
- instance = self.opts.model()
-
- for key, val in attrs.items():
- try:
- setattr(instance, key, val)
- except ValueError:
- self._errors[key] = [self.error_messages['required']]
-
- # Any relations that cannot be set until we've
- # saved the model get hidden away on these
- # private attributes, so we can deal with them
- # at the point of save.
- instance._related_data = related_data
- instance._m2m_data = m2m_data
- instance._nested_forward_relations = nested_forward_relations
-
- return instance
-
- def from_native(self, data, files):
- """
- Override the default method to also include model field validation.
- """
- instance = super(ModelSerializer, self).from_native(data, files)
- if not self._errors:
- return self.full_clean(instance)
-
- def save_object(self, obj, **kwargs):
- """
- Save the deserialized object.
- """
- if getattr(obj, '_nested_forward_relations', None):
- # Nested relationships need to be saved before we can save the
- # parent instance.
- for field_name, sub_object in obj._nested_forward_relations.items():
- if sub_object:
- self.save_object(sub_object)
- setattr(obj, field_name, sub_object)
-
- obj.save(**kwargs)
-
- if getattr(obj, '_m2m_data', None):
- for accessor_name, object_list in obj._m2m_data.items():
- setattr(obj, accessor_name, object_list)
- del(obj._m2m_data)
-
- if getattr(obj, '_related_data', None):
- related_fields = dict([
- (field.get_accessor_name(), field)
- for field, model
- in obj._meta.get_all_related_objects_with_model()
- ])
- for accessor_name, related in obj._related_data.items():
- if isinstance(related, RelationsList):
- # Nested reverse fk relationship
- for related_item in related:
- fk_field = related_fields[accessor_name].field.name
- setattr(related_item, fk_field, obj)
- self.save_object(related_item)
-
- # Delete any removed objects
- if related._deleted:
- [self.delete_object(item) for item in related._deleted]
-
- elif isinstance(related, models.Model):
- # Nested reverse one-one relationship
- fk_field = obj._meta.get_field_by_name(accessor_name)[0].field.name
- setattr(related, fk_field, obj)
- self.save_object(related)
- else:
- # Reverse FK or reverse one-one
- setattr(obj, accessor_name, related)
- del(obj._related_data)
+ # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)`
+ return CharField(**kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
@@ -1066,14 +564,10 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
class HyperlinkedModelSerializer(ModelSerializer):
- """
- A subclass of ModelSerializer that uses hyperlinked relationships,
- instead of primary key relationships.
- """
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
- _hyperlink_field_class = HyperlinkedRelatedField
- _hyperlink_identify_field_class = HyperlinkedIdentityField
+ #_hyperlink_field_class = HyperlinkedRelatedField
+ #_hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
@@ -1081,15 +575,15 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
- if self.opts.url_field_name not in fields:
- url_field = self._hyperlink_identify_field_class(
- view_name=self.opts.view_name,
- lookup_field=self.opts.lookup_field
- )
- ret = self._dict_class()
- ret[self.opts.url_field_name] = url_field
- ret.update(fields)
- fields = ret
+ # if self.opts.url_field_name not in fields:
+ # url_field = self._hyperlink_identify_field_class(
+ # view_name=self.opts.view_name,
+ # lookup_field=self.opts.lookup_field
+ # )
+ # ret = self._dict_class()
+ # ret[self.opts.url_field_name] = url_field
+ # ret.update(fields)
+ # fields = ret
return fields
@@ -1103,33 +597,25 @@ class HyperlinkedModelSerializer(ModelSerializer):
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
- kwargs = {
- 'queryset': related_model._default_manager,
- 'view_name': self._get_default_view_name(related_model),
- 'many': to_many
- }
+ # kwargs = {
+ # 'queryset': related_model._default_manager,
+ # 'view_name': self._get_default_view_name(related_model),
+ # 'many': to_many
+ # }
+ kwargs = {}
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
- if model_field.help_text is not None:
- kwargs['help_text'] = model_field.help_text
+ # if model_field.help_text is not None:
+ # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
- if self.opts.lookup_field:
- kwargs['lookup_field'] = self.opts.lookup_field
+ return IntegerField(**kwargs)
+ # if self.opts.lookup_field:
+ # kwargs['lookup_field'] = self.opts.lookup_field
- return self._hyperlink_field_class(**kwargs)
-
- def get_identity(self, data):
- """
- This hook is required for bulk update.
- We need to override the default, to use the url as the identity.
- """
- try:
- return data.get(self.opts.url_field_name, None)
- except AttributeError:
- return None
+ # return self._hyperlink_field_class(**kwargs)
def _get_default_view_name(self, model):
"""
diff --git a/rest_framework/settings.py b/rest_framework/settings.py
index 644751f87..bbe7a56ad 100644
--- a/rest_framework/settings.py
+++ b/rest_framework/settings.py
@@ -111,9 +111,6 @@ DEFAULTS = {
),
'TIME_FORMAT': None,
- # Pending deprecation
- 'FILTER_BACKEND': None,
-
}
@@ -129,7 +126,6 @@ IMPORT_STRINGS = (
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
'DEFAULT_FILTER_BACKENDS',
'EXCEPTION_HANDLER',
- 'FILTER_BACKEND',
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
@@ -196,15 +192,9 @@ class APISettings(object):
if val and attr in self.import_strings:
val = perform_import(val, attr)
- self.validate_setting(attr, val)
-
# Cache the result
setattr(self, attr, val)
return val
- def validate_setting(self, attr, val):
- if attr == 'FILTER_BACKEND' and val is not None:
- # Make sure we can initialize the class
- val()
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)
diff --git a/rest_framework/utils/encoders.py b/rest_framework/utils/encoders.py
index 00ffdfbae..6a2f61266 100644
--- a/rest_framework/utils/encoders.py
+++ b/rest_framework/utils/encoders.py
@@ -7,7 +7,7 @@ from django.db.models.query import QuerySet
from django.utils.datastructures import SortedDict
from django.utils.functional import Promise
from rest_framework.compat import force_text
-from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
+# from rest_framework.serializers import DictWithMetadata, SortedDictWithMetadata
import datetime
import decimal
import types
@@ -106,14 +106,14 @@ else:
SortedDict,
yaml.representer.SafeRepresenter.represent_dict
)
- SafeDumper.add_representer(
- DictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict
- )
- SafeDumper.add_representer(
- SortedDictWithMetadata,
- yaml.representer.SafeRepresenter.represent_dict
- )
+ # SafeDumper.add_representer(
+ # DictWithMetadata,
+ # yaml.representer.SafeRepresenter.represent_dict
+ # )
+ # SafeDumper.add_representer(
+ # SortedDictWithMetadata,
+ # yaml.representer.SafeRepresenter.represent_dict
+ # )
SafeDumper.add_representer(
types.GeneratorType,
yaml.representer.SafeRepresenter.represent_list
diff --git a/rest_framework/utils/html.py b/rest_framework/utils/html.py
new file mode 100644
index 000000000..bf17050df
--- /dev/null
+++ b/rest_framework/utils/html.py
@@ -0,0 +1,86 @@
+"""
+Helpers for dealing with HTML input.
+"""
+
+def is_html_input(dictionary):
+ # MultiDict type datastructures are used to represent HTML form input,
+ # which may have more than one value for each key.
+ return hasattr(dictionary, 'getlist')
+
+
+def parse_html_list(dictionary, prefix=''):
+ """
+ Used to suport list values in HTML forms.
+ Supports lists of primitives and/or dictionaries.
+
+ * List of primitives.
+
+ {
+ '[0]': 'abc',
+ '[1]': 'def',
+ '[2]': 'hij'
+ }
+ -->
+ [
+ 'abc',
+ 'def',
+ 'hij'
+ ]
+
+ * List of dictionaries.
+
+ {
+ '[0]foo': 'abc',
+ '[0]bar': 'def',
+ '[1]foo': 'hij',
+ '[2]bar': 'klm',
+ }
+ -->
+ [
+ {'foo': 'abc', 'bar': 'def'},
+ {'foo': 'hij', 'bar': 'klm'}
+ ]
+ """
+ Dict = type(dictionary)
+ ret = {}
+ regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
+ for field, value in dictionary.items():
+ match = regex.match(field)
+ if not match:
+ continue
+ index, key = match.groups()
+ index = int(index)
+ if not key:
+ ret[index] = value
+ elif isinstance(ret.get(index), dict):
+ ret[index][key] = value
+ else:
+ ret[index] = Dict({key: value})
+ return [ret[item] for item in sorted(ret.keys())]
+
+
+def parse_html_dict(dictionary, prefix):
+ """
+ Used to support dictionary values in HTML forms.
+
+ {
+ 'profile.username': 'example',
+ 'profile.email': 'example@example.com',
+ }
+ -->
+ {
+ 'profile': {
+ 'username': 'example,
+ 'email': 'example@example.com'
+ }
+ }
+ """
+ ret = {}
+ regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
+ for field, value in dictionary.items():
+ match = regex.match(field)
+ if not match:
+ continue
+ key = match.groups()[0]
+ ret[key] = value
+ return ret
diff --git a/tests/serializers.py b/tests/serializers.py
deleted file mode 100644
index be7b37722..000000000
--- a/tests/serializers.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from rest_framework import serializers
-from tests.models import NullableForeignKeySource
-
-
-class NullableFKSourceSerializer(serializers.ModelSerializer):
- class Meta:
- model = NullableForeignKeySource
diff --git a/tests/test_filters.py b/tests/test_filters.py
index 47bffd436..6f24b1abb 100644
--- a/tests/test_filters.py
+++ b/tests/test_filters.py
@@ -16,9 +16,14 @@ factory = APIRequestFactory()
if django_filters:
+ class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
# Basic filter on a list view.
class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
filter_fields = ['decimal', 'date']
filter_backends = (filters.DjangoFilterBackend,)
@@ -33,7 +38,8 @@ if django_filters:
fields = ['text', 'decimal', 'date']
class FilterClassRootView(generics.ListCreateAPIView):
- model = FilterableItem
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
filter_class = SeveralFieldsFilter
filter_backends = (filters.DjangoFilterBackend,)
@@ -46,12 +52,14 @@ if django_filters:
fields = ['text']
class IncorrectlyConfiguredRootView(generics.ListCreateAPIView):
- model = FilterableItem
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
filter_class = MisconfiguredFilter
filter_backends = (filters.DjangoFilterBackend,)
class FilterClassDetailView(generics.RetrieveAPIView):
- model = FilterableItem
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
filter_class = SeveralFieldsFilter
filter_backends = (filters.DjangoFilterBackend,)
@@ -63,15 +71,12 @@ if django_filters:
model = BaseFilterableItem
class BaseFilterableItemFilterRootView(generics.ListCreateAPIView):
- model = FilterableItem
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
filter_class = BaseFilterableItemFilter
filter_backends = (filters.DjangoFilterBackend,)
# Regression test for #814
- class FilterableItemSerializer(serializers.ModelSerializer):
- class Meta:
- model = FilterableItem
-
class FilterFieldsQuerysetView(generics.ListCreateAPIView):
queryset = FilterableItem.objects.all()
serializer_class = FilterableItemSerializer
@@ -323,6 +328,11 @@ class SearchFilterModel(models.Model):
text = models.CharField(max_length=100)
+class SearchFilterSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = SearchFilterModel
+
+
class SearchFilterTests(TestCase):
def setUp(self):
# Sequence of title/text is:
@@ -342,7 +352,8 @@ class SearchFilterTests(TestCase):
def test_search(self):
class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
@@ -359,7 +370,8 @@ class SearchFilterTests(TestCase):
def test_exact_search(self):
class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
@@ -375,7 +387,8 @@ class SearchFilterTests(TestCase):
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
@@ -392,7 +405,8 @@ class SearchFilterTests(TestCase):
def test_search_with_nonstandard_search_param(self):
with temporary_setting('SEARCH_PARAM', 'query', module=filters):
class SearchListView(generics.ListAPIView):
- model = SearchFilterModel
+ queryset = SearchFilterModel.objects.all()
+ serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
@@ -418,6 +432,11 @@ class OrderingFilterRelatedModel(models.Model):
related_name="relateds")
+class OrderingFilterSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = OrdringFilterModel
+
+
class OrderingFilterTests(TestCase):
def setUp(self):
# Sequence of title/text is:
@@ -440,7 +459,8 @@ class OrderingFilterTests(TestCase):
def test_ordering(self):
class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
+ queryset = OrdringFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
@@ -459,7 +479,8 @@ class OrderingFilterTests(TestCase):
def test_reverse_ordering(self):
class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
+ queryset = OrdringFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
@@ -478,7 +499,8 @@ class OrderingFilterTests(TestCase):
def test_incorrectfield_ordering(self):
class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
+ queryset = OrdringFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
@@ -497,7 +519,8 @@ class OrderingFilterTests(TestCase):
def test_default_ordering(self):
class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
+ queryset = OrdringFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
oredering_fields = ('text',)
@@ -516,7 +539,8 @@ class OrderingFilterTests(TestCase):
def test_default_ordering_using_string(self):
class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
+ queryset = OrdringFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
ordering_fields = ('text',)
@@ -545,7 +569,7 @@ class OrderingFilterTests(TestCase):
new_related.save()
class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
+ serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
ordering_fields = '__all__'
@@ -567,7 +591,8 @@ class OrderingFilterTests(TestCase):
def test_ordering_with_nonstandard_ordering_param(self):
with temporary_setting('ORDERING_PARAM', 'order', filters):
class OrderingListView(generics.ListAPIView):
- model = OrdringFilterModel
+ queryset = OrdringFilterModel.objects.all()
+ serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
diff --git a/tests/test_generics.py b/tests/test_generics.py
index e9f5bebdd..f50d53e99 100644
--- a/tests/test_generics.py
+++ b/tests/test_generics.py
@@ -11,18 +11,30 @@ from tests.models import ForeignKeySource, ForeignKeyTarget
factory = APIRequestFactory()
+class BasicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+
+class ForeignKeySerializer(serializers.ModelSerializer):
+ class Meta:
+ model = ForeignKeySource
+
+
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.
"""
- model = BasicModel
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
class InstanceView(generics.RetrieveUpdateDestroyAPIView):
"""
Example description for OPTIONS.
"""
- model = BasicModel
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
def get_queryset(self):
queryset = super(InstanceView, self).get_queryset()
@@ -33,7 +45,8 @@ class FKInstanceView(generics.RetrieveUpdateDestroyAPIView):
"""
FK: example description for OPTIONS.
"""
- model = ForeignKeySource
+ queryset = ForeignKeySource.objects.all()
+ serializer_class = ForeignKeySerializer
class SlugSerializer(serializers.ModelSerializer):
@@ -48,7 +61,7 @@ class SlugBasedInstanceView(InstanceView):
"""
A model with a slug-field.
"""
- model = SlugBasedModel
+ queryset = SlugBasedModel.objects.all()
serializer_class = SlugSerializer
lookup_field = 'slug'
@@ -503,7 +516,7 @@ class TestOverriddenGetObject(TestCase):
"""
Example detail view for override of get_object().
"""
- model = BasicModel
+ serializer_class = BasicSerializer
def get_object(self):
pk = int(self.kwargs['pk'])
@@ -573,7 +586,7 @@ class ClassASerializer(serializers.ModelSerializer):
class ExampleView(generics.ListCreateAPIView):
serializer_class = ClassASerializer
- model = ClassA
+ queryset = ClassA.objects.all()
class TestM2MBrowseableAPI(TestCase):
@@ -603,7 +616,7 @@ class TwoFieldModel(models.Model):
class DynamicSerializerView(generics.ListCreateAPIView):
- model = TwoFieldModel
+ queryset = TwoFieldModel.objects.all()
renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
def get_serializer_class(self):
@@ -612,8 +625,11 @@ class DynamicSerializerView(generics.ListCreateAPIView):
class Meta:
model = TwoFieldModel
fields = ('field_b',)
- return DynamicSerializer
- return super(DynamicSerializerView, self).get_serializer_class()
+ else:
+ class DynamicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = TwoFieldModel
+ return DynamicSerializer
class TestFilterBackendAppliedToViews(TestCase):
diff --git a/tests/test_hyperlinkedserializers.py b/tests/test_hyperlinkedserializers.py
index d45485391..0e8c1ed46 100644
--- a/tests/test_hyperlinkedserializers.py
+++ b/tests/test_hyperlinkedserializers.py
@@ -39,59 +39,85 @@ class AlbumSerializer(serializers.ModelSerializer):
fields = ('title', 'url')
+class BasicSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = BasicModel
+
+
+class AnchorSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = Anchor
+
+
+class ManyToManySerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = ManyToManyModel
+
+
+class BlogPostSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BlogPost
+
+
+class OptionalRelationSerializer(serializers.HyperlinkedModelSerializer):
+ class Meta:
+ model = OptionalRelationModel
+
+
class BasicList(generics.ListCreateAPIView):
- model = BasicModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
class BasicDetail(generics.RetrieveUpdateDestroyAPIView):
- model = BasicModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
class AnchorDetail(generics.RetrieveAPIView):
- model = Anchor
- model_serializer_class = serializers.HyperlinkedModelSerializer
+ queryset = Anchor.objects.all()
+ serializer_class = AnchorSerializer
class ManyToManyList(generics.ListAPIView):
- model = ManyToManyModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
+ queryset = ManyToManyModel.objects.all()
+ serializer_class = ManyToManySerializer
class ManyToManyDetail(generics.RetrieveAPIView):
- model = ManyToManyModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
+ queryset = ManyToManyModel.objects.all()
+ serializer_class = ManyToManySerializer
class BlogPostCommentListCreate(generics.ListCreateAPIView):
- model = BlogPostComment
+ queryset = BlogPostComment.objects.all()
serializer_class = BlogPostCommentSerializer
class BlogPostCommentDetail(generics.RetrieveAPIView):
- model = BlogPostComment
+ queryset = BlogPostComment.objects.all()
serializer_class = BlogPostCommentSerializer
class BlogPostDetail(generics.RetrieveAPIView):
- model = BlogPost
+ queryset = BlogPost.objects.all()
+ serializer_class = BlogPostSerializer
class PhotoListCreate(generics.ListCreateAPIView):
- model = Photo
- model_serializer_class = PhotoSerializer
+ queryset = Photo.objects.all()
+ serializer_class = PhotoSerializer
class AlbumDetail(generics.RetrieveAPIView):
- model = Album
+ queryset = Album.objects.all()
serializer_class = AlbumSerializer
lookup_field = 'title'
class OptionalRelationDetail(generics.RetrieveUpdateDestroyAPIView):
- model = OptionalRelationModel
- model_serializer_class = serializers.HyperlinkedModelSerializer
+ queryset = OptionalRelationModel.objects.all()
+ serializer_class = OptionalRelationSerializer
urlpatterns = patterns(
diff --git a/tests/test_nullable_fields.py b/tests/test_nullable_fields.py
index 0c133fc2c..8d0c84bb0 100644
--- a/tests/test_nullable_fields.py
+++ b/tests/test_nullable_fields.py
@@ -1,10 +1,19 @@
from django.core.urlresolvers import reverse
from django.conf.urls import patterns, url
+from rest_framework import serializers, generics
from rest_framework.test import APITestCase
from tests.models import NullableForeignKeySource
-from tests.serializers import NullableFKSourceSerializer
-from tests.views import NullableFKSourceDetail
+
+
+class NullableFKSourceSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = NullableForeignKeySource
+
+
+class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
+ queryset = NullableForeignKeySource.objects.all()
+ serializer_class = NullableFKSourceSerializer
urlpatterns = patterns(
diff --git a/tests/test_pagination.py b/tests/test_pagination.py
index 80c33e2eb..8f9e0005e 100644
--- a/tests/test_pagination.py
+++ b/tests/test_pagination.py
@@ -4,7 +4,7 @@ from decimal import Decimal
from django.core.paginator import Paginator
from django.test import TestCase
from django.utils import unittest
-from rest_framework import generics, status, pagination, filters, serializers
+from rest_framework import generics, serializers, status, pagination, filters
from rest_framework.compat import django_filters
from rest_framework.test import APIRequestFactory
from .models import BasicModel, FilterableItem
@@ -22,11 +22,22 @@ def split_arguments_from_url(url):
return path, args
+class BasicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+
+class FilterableItemSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = FilterableItem
+
+
class RootView(generics.ListCreateAPIView):
"""
Example description for OPTIONS.
"""
- model = BasicModel
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
paginate_by = 10
@@ -34,14 +45,16 @@ class DefaultPageSizeKwargView(generics.ListAPIView):
"""
View for testing default paginate_by_param usage
"""
- model = BasicModel
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
class PaginateByParamView(generics.ListAPIView):
"""
View for testing custom paginate_by_param usage
"""
- model = BasicModel
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
paginate_by_param = 'page_size'
@@ -49,7 +62,8 @@ class MaxPaginateByView(generics.ListAPIView):
"""
View for testing custom max_paginate_by usage
"""
- model = BasicModel
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
paginate_by = 3
max_paginate_by = 5
paginate_by_param = 'page_size'
@@ -140,7 +154,8 @@ class IntegrationTestPaginationAndFiltering(TestCase):
fields = ['text', 'decimal', 'date']
class FilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
paginate_by = 10
filter_class = DecimalFilter
filter_backends = (filters.DjangoFilterBackend,)
@@ -188,7 +203,8 @@ class IntegrationTestPaginationAndFiltering(TestCase):
return queryset.filter(decimal__lt=Decimal(request.GET['decimal']))
class BasicFilterFieldsRootView(generics.ListCreateAPIView):
- model = FilterableItem
+ queryset = FilterableItem.objects.all()
+ serializer_class = FilterableItemSerializer
paginate_by = 10
filter_backends = (DecimalFilterBackend,)
@@ -387,7 +403,7 @@ class TestContextPassedToCustomField(TestCase):
def test_with_pagination(self):
class ListView(generics.ListCreateAPIView):
- model = BasicModel
+ queryset = BasicModel.objects.all()
serializer_class = BasicModelSerializer
paginate_by = 1
diff --git a/tests/test_permissions.py b/tests/test_permissions.py
index 93f8020f3..b90ba4f19 100644
--- a/tests/test_permissions.py
+++ b/tests/test_permissions.py
@@ -3,7 +3,7 @@ from django.contrib.auth.models import User, Permission, Group
from django.db import models
from django.test import TestCase
from django.utils import unittest
-from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING
+from rest_framework import generics, serializers, status, permissions, authentication, HTTP_HEADER_ENCODING
from rest_framework.compat import guardian, get_model_name
from rest_framework.filters import DjangoObjectPermissionsFilter
from rest_framework.test import APIRequestFactory
@@ -13,14 +13,21 @@ import base64
factory = APIRequestFactory()
+class BasicSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicModel
+
+
class RootView(generics.ListCreateAPIView):
- model = BasicModel
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions]
class InstanceView(generics.RetrieveUpdateDestroyAPIView):
- model = BasicModel
+ queryset = BasicModel.objects.all()
+ serializer_class = BasicSerializer
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [permissions.DjangoModelPermissions]
@@ -167,6 +174,11 @@ class BasicPermModel(models.Model):
)
+class BasicPermSerializer(serializers.ModelSerializer):
+ class Meta:
+ model = BasicPermModel
+
+
# Custom object-level permission, that includes 'view' permissions
class ViewObjectPermissions(permissions.DjangoObjectPermissions):
perms_map = {
@@ -181,7 +193,8 @@ class ViewObjectPermissions(permissions.DjangoObjectPermissions):
class ObjectPermissionInstanceView(generics.RetrieveUpdateDestroyAPIView):
- model = BasicPermModel
+ queryset = BasicPermModel.objects.all()
+ serializer_class = BasicPermSerializer
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions]
@@ -189,7 +202,8 @@ object_permissions_view = ObjectPermissionInstanceView.as_view()
class ObjectPermissionListView(generics.ListAPIView):
- model = BasicPermModel
+ queryset = BasicPermModel.objects.all()
+ serializer_class = BasicPermSerializer
authentication_classes = [authentication.BasicAuthentication]
permission_classes = [ViewObjectPermissions]
diff --git a/tests/test_response.py b/tests/test_response.py
index 2eff83d3d..004c565c9 100644
--- a/tests/test_response.py
+++ b/tests/test_response.py
@@ -86,14 +86,15 @@ class HTMLView1(APIView):
class HTMLNewModelViewSet(viewsets.ModelViewSet):
- model = BasicModel
+ serializer_class = BasicModelSerializer
+ queryset = BasicModel.objects.all()
class HTMLNewModelView(generics.ListCreateAPIView):
renderer_classes = (BrowsableAPIRenderer,)
permission_classes = []
serializer_class = BasicModelSerializer
- model = BasicModel
+ queryset = BasicModel.objects.all()
new_model_viewset_router = routers.DefaultRouter()
diff --git a/tests/test_validation.py b/tests/test_validation.py
index e13e4078c..f62d9068b 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -22,7 +22,7 @@ class ValidationModelSerializer(serializers.ModelSerializer):
class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
- model = ValidationModel
+ queryset = ValidationModel.objects.all()
serializer_class = ValidationModelSerializer
@@ -117,7 +117,7 @@ class ValidationMaxValueValidatorModelSerializer(serializers.ModelSerializer):
class UpdateMaxValueValidationModel(generics.RetrieveUpdateDestroyAPIView):
- model = ValidationMaxValueValidatorModel
+ queryset = ValidationMaxValueValidatorModel.objects.all()
serializer_class = ValidationMaxValueValidatorModelSerializer
diff --git a/tests/views.py b/tests/views.py
deleted file mode 100644
index 55935e924..000000000
--- a/tests/views.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from rest_framework import generics
-from .models import NullableForeignKeySource
-from .serializers import NullableFKSourceSerializer
-
-
-class NullableFKSourceDetail(generics.RetrieveUpdateDestroyAPIView):
- model = NullableForeignKeySource
- model_serializer_class = NullableFKSourceSerializer