django-rest-framework/rest_framework/serializers.py
2014-09-19 16:43:13 +01:00

491 lines
18 KiB
Python

"""
Serializers and ModelSerializers are similar to Forms and ModelForms.
Unlike forms, they are not constrained to dealing with HTML output, and
form encoded input.
Serialization in REST framework is a two-phase process:
1. Serializers marshal between complex types like model instances, and
python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.db import models
from django.utils import six
from django.utils.datastructures import SortedDict
from collections import namedtuple
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation
from rest_framework.utils.field_mapping import (
get_url_kwargs, get_field_kwargs,
get_relation_kwargs, get_nested_relation_kwargs,
lookup_class
)
import copy
import inspect
# Note: We do the following so that users of the framework can use this style:
#
# example_field = serializers.CharField(...)
#
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA
# Lightweight record yielded when iterating over a serializer: the field
# instance, its current value, and any error recorded against it.
FieldResult = namedtuple('FieldResult', 'field value error')
class BaseSerializer(Field):
    """
    A minimal serializer base class, intended as the starting point for
    custom serializer implementations.

    Subclasses supply `to_internal_value()`, `to_representation()`,
    `create()` and `update()`. On top of those hooks this class provides
    the `save()` / `is_valid()` workflow and the lazily computed `data`,
    `errors` and `validated_data` properties.
    """

    def __init__(self, instance=None, data=None, **kwargs):
        # Any remaining keyword arguments are handed to `Field`.
        super(BaseSerializer, self).__init__(**kwargs)
        self.instance = instance
        self._initial_data = data

    # --- Hooks that concrete serializers must implement -----------------

    def to_internal_value(self, data):
        raise NotImplementedError('`to_internal_value()` must be implemented.')

    def to_representation(self, instance):
        raise NotImplementedError('`to_representation()` must be implemented.')

    def update(self, instance, attrs):
        raise NotImplementedError('`update()` must be implemented.')

    def create(self, attrs):
        raise NotImplementedError('`create()` must be implemented.')

    # --- Public workflow ------------------------------------------------

    def save(self, extras=None):
        """
        Persist the validated data and return `self.instance`.

        When `extras` is given, its items are merged over the validated
        data (on a key clash the value from `extras` wins). An existing
        instance is passed to `update()`; otherwise the result of
        `create()` becomes the new `self.instance`.
        """
        validated = self.validated_data
        if extras is not None:
            validated = dict(list(validated.items()) + list(extras.items()))

        if self.instance is None:
            self.instance = self.create(validated)
        else:
            self.update(self.instance, validated)
        return self.instance

    def is_valid(self, raise_exception=False):
        """
        Validate the initial data, returning `True` when there are no errors.

        Validation runs only on the first call; later calls reuse the
        stored outcome. With `raise_exception=True` a `ValidationError`
        carrying the collected errors is raised instead of returning `False`.
        """
        if not hasattr(self, '_validated_data'):
            try:
                validated = self.to_internal_value(self._initial_data)
            except ValidationError as exc:
                self._validated_data = {}
                # NOTE(review): relies on the error having been raised
                # with a dict of messages -- confirm for custom subclasses.
                self._errors = exc.message_dict
            else:
                self._validated_data = validated
                self._errors = {}

        if self._errors and raise_exception:
            raise ValidationError(self._errors)
        return not self._errors

    @property
    def data(self):
        """
        The outgoing representation, computed once and then cached.

        Uses the instance when there is one, otherwise the per-field
        values extracted from the initial data, otherwise the value of
        `get_initial()`.
        """
        if hasattr(self, '_data'):
            return self._data

        if self.instance is not None:
            self._data = self.to_representation(self.instance)
        elif self._initial_data is not None:
            # `self.fields` is not defined in this class; it is expected
            # to be supplied by the concrete serializer.
            pairs = [
                (field_name, field.get_value(self._initial_data))
                for field_name, field in self.fields.items()
            ]
            self._data = dict(pairs)
        else:
            self._data = self.get_initial()
        return self._data

    @property
    def errors(self):
        """The errors collected by `is_valid()`, which must be called first."""
        if hasattr(self, '_errors'):
            return self._errors
        msg = 'You must call `.is_valid()` before accessing `.errors`.'
        raise AssertionError(msg)

    @property
    def validated_data(self):
        """The data produced by `is_valid()`, which must be called first."""
        if hasattr(self, '_validated_data'):
            return self._validated_data
        msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
        raise AssertionError(msg)
class SerializerMetaclass(type):
    """
    Metaclass that gathers declared fields into `_declared_fields`.

    Every `Field` instance assigned as an attribute of the class being
    created is removed from the class namespace and stored, together with
    the `_declared_fields` of any base classes, in an ordered dictionary
    named `_declared_fields` on the new class.
    """

    @classmethod
    def _get_declared_fields(cls, bases, attrs):
        """
        Return a `SortedDict` of field name -> field for the new class.

        `attrs` is mutated: each `Field` found in it is popped out.
        """
        own_fields = []
        # Walk a snapshot of the items, as matches are popped from `attrs`.
        for attr_name, attr_value in list(attrs.items()):
            if isinstance(attr_value, Field):
                own_fields.append((attr_name, attrs.pop(attr_name)))
        own_fields.sort(key=lambda pair: pair[1]._creation_counter)

        # Prepend the fields of each parent serializer. The bases are
        # visited in *reverse* so that, after repeated prepending, the
        # fields of the first base end up first.
        collected = own_fields
        for base in reversed(bases):
            if hasattr(base, '_declared_fields'):
                collected = list(base._declared_fields.items()) + collected
        return SortedDict(collected)

    def __new__(cls, name, bases, attrs):
        declared = cls._get_declared_fields(bases, attrs)
        attrs['_declared_fields'] = declared
        return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
    """
    A serializer made up of declared child fields.

    Incoming dictionaries are validated field by field, and instances
    are rendered into an ordered dictionary of primitive values.
    Instantiating with `many=True` returns a `ListSerializer` wrapping
    an instance of this class instead.
    """

    def __init__(self, *args, **kwargs):
        self.context = kwargs.pop('context', {})
        # Strip the `partial` and `many` flags before delegating to the
        # parent initializer.
        for flag in ('partial', 'many'):
            kwargs.pop(flag, None)
        super(Serializer, self).__init__(*args, **kwargs)

        # Each instance works on its own clone of the fields, so that
        # they can be modified dynamically without leaking into other
        # serializer instances.
        self.fields = self._get_base_fields()

        # Bind every child field to this serializer (as parent and root).
        for name, child in self.fields.items():
            child.bind(name, self, self)

    def __new__(cls, *args, **kwargs):
        # Overridden so that `many=True` transparently produces a
        # `ListSerializer` whose child is an instance of this class.
        wants_list = kwargs.pop('many', False)
        if not wants_list:
            return super(Serializer, cls).__new__(cls, *args, **kwargs)
        kwargs['child'] = cls()
        return ListSerializer(*args, **kwargs)

    def _get_base_fields(self):
        """Return a deep copy of the fields declared on the class."""
        return copy.deepcopy(self._declared_fields)

    def bind(self, field_name, parent, root):
        # When the serializer is itself used as a field, binding it must
        # also re-bind all of its children against the new root.
        super(Serializer, self).bind(field_name, parent, root)
        for name, child in self.fields.items():
            child.bind(name, self, root)

    def get_initial(self):
        """Return a dict mapping each field name to its initial value."""
        initial = {}
        for field in self.fields.values():
            initial[field.field_name] = field.get_initial()
        return initial

    def get_value(self, dictionary):
        # Default field access is overridden to support nested HTML forms.
        if not html.is_html_input(dictionary):
            return dictionary.get(self.field_name, empty)
        return html.parse_html_dict(dictionary, prefix=self.field_name)

    def to_internal_value(self, data):
        """
        Dict of native values <- Dict of primitive datatypes.

        Runs each writable field's validation, then the optional
        `validate_<field_name>()` method, and finally `validate()`.
        Raises `ValidationError` with a dict of messages on failure.
        """
        if not isinstance(data, dict):
            raise ValidationError({
                api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
            })

        validated = {}
        field_errors = {}
        writable = [f for f in self.fields.values() if not f.read_only]

        for field in writable:
            name = field.field_name
            extra_validator = getattr(self, 'validate_' + name, None)
            raw_value = field.get_value(data)
            try:
                value = field.run_validation(raw_value)
                if extra_validator is not None:
                    value = extra_validator(value)
            except ValidationError as exc:
                field_errors[name] = exc.messages
            except SkipField:
                # The field asked to be left out of the result.
                pass
            else:
                set_value(validated, field.source_attrs, value)

        if field_errors:
            raise ValidationError(field_errors)

        try:
            return self.validate(validated)
        except ValidationError as exc:
            # Errors from `validate()` are not tied to a single field.
            non_field = {api_settings.NON_FIELD_ERRORS_KEY: exc.messages}
            raise ValidationError(non_field)

    def to_representation(self, instance):
        """
        Object instance -> Dict of primitive datatypes.

        Write-only fields are left out of the output.
        """
        output = SortedDict()
        readable = [f for f in self.fields.values() if not f.write_only]
        for field in readable:
            attribute = field.get_attribute(instance)
            output[field.field_name] = field.to_representation(attribute)
        return output

    def validate(self, attrs):
        """Hook for object-level validation; returns `attrs` unchanged."""
        return attrs

    def __iter__(self):
        """Yield a `FieldResult` (field, value, error) for every field."""
        if hasattr(self, '_errors'):
            errors = self.errors
        else:
            errors = {}

        for field in self.fields.values():
            name = field.field_name
            if self.data:
                value = self.data.get(name)
            else:
                value = None
            yield FieldResult(field, value, errors.get(name))

    def __repr__(self):
        return representation.serializer_repr(self, indent=1)
class ListSerializer(BaseSerializer):
    """
    A serializer that applies a `child` serializer or field to every
    item of a list.

    The child may be set as a class attribute on a subclass or passed as
    the `child` keyword argument; it must be an instance, not a class.
    """
    child = None
    initial = []

    def __init__(self, *args, **kwargs):
        # A class-level `child` is deep-copied so instances do not share it.
        self.child = kwargs.pop('child', copy.deepcopy(self.child))
        assert self.child is not None, '`child` is a required argument.'
        assert not inspect.isclass(self.child), '`child` has not been instantiated.'
        self.context = kwargs.pop('context', {})
        kwargs.pop('partial', None)
        super(ListSerializer, self).__init__(*args, **kwargs)
        self.child.bind('', self, self)

    def bind(self, field_name, parent, root):
        # If the list is used as a field then it needs to provide
        # the current context to the child serializer.
        super(ListSerializer, self).bind(field_name, parent, root)
        self.child.bind(field_name, self, root)

    def get_value(self, dictionary):
        # We override the default field access in order to support
        # lists in HTML forms.
        # Use the `html` module explicitly, as the other HTML checks in
        # this file do, rather than relying on a bare `is_html_input`
        # name happening to be in scope.
        if html.is_html_input(dictionary):
            return html.parse_html_list(dictionary, prefix=self.field_name)
        return dictionary.get(self.field_name, empty)

    def to_internal_value(self, data):
        """
        List of dicts of native values <- List of dicts of primitive datatypes.
        """
        if html.is_html_input(data):
            data = html.parse_html_list(data)
        return [self.child.run_validation(item) for item in data]

    def to_representation(self, data):
        """
        List of object instances -> List of dicts of primitive datatypes.
        """
        return [self.child.to_representation(item) for item in data]

    def create(self, attrs_list):
        """Create one object per item by delegating to the child's `create()`."""
        return [self.child.create(attrs) for attrs in attrs_list]

    def save(self, extras=None):
        """
        Persist the validated list and return `self.instance`.

        An existing instance is passed to `update()`; only when there is
        no instance is `create()` called. (Previously `create()` also ran
        after an update and replaced `self.instance`.)

        `extras`, when given, is merged over each item of the validated
        data, matching the signature of `BaseSerializer.save()`. Each
        item must then be a dict.
        """
        attrs_list = self.validated_data
        if extras is not None:
            attrs_list = [
                dict(list(attrs.items()) + list(extras.items()))
                for attrs in attrs_list
            ]

        if self.instance is not None:
            self.update(self.instance, attrs_list)
        else:
            self.instance = self.create(attrs_list)
        return self.instance

    def __repr__(self):
        return representation.list_repr(self, indent=1)
class ModelSerializer(Serializer):
    """
    A serializer whose fields are generated from a Django model.

    Configuration is read from an inner `Meta` class: `model` (required),
    and optionally `fields`, `depth` and `extra_kwargs`.
    """

    # Maps Django model field classes to the serializer field class used
    # to represent them; resolved through `lookup_class()`.
    _field_mapping = {
        models.AutoField: IntegerField,
        models.BigIntegerField: IntegerField,
        models.BooleanField: BooleanField,
        models.CharField: CharField,
        models.CommaSeparatedIntegerField: CharField,
        models.DateField: DateField,
        models.DateTimeField: DateTimeField,
        models.DecimalField: DecimalField,
        models.EmailField: EmailField,
        models.Field: ModelField,
        models.FileField: FileField,
        models.FloatField: FloatField,
        models.ImageField: ImageField,
        models.IntegerField: IntegerField,
        models.NullBooleanField: BooleanField,
        models.PositiveIntegerField: IntegerField,
        models.PositiveSmallIntegerField: IntegerField,
        models.SlugField: SlugField,
        models.SmallIntegerField: IntegerField,
        models.TextField: CharField,
        models.TimeField: TimeField,
        models.URLField: URLField,
    }
    # Field class used for relationships when `Meta.depth` is not set.
    _related_class = PrimaryKeyRelatedField

    def create(self, attrs):
        """
        Create and return a model instance from `attrs`.

        Note that to-many relationship entries are popped out of `attrs`.
        """
        ModelClass = self.Meta.model

        # Remove many-to-many relationships from attrs.
        # They are not valid arguments to the default `.create()` method,
        # as they require that the instance has already been saved.
        info = model_meta.get_field_info(ModelClass)
        many_to_many = {}
        for field_name, relation_info in info.relations.items():
            if relation_info.to_many and (field_name in attrs):
                many_to_many[field_name] = attrs.pop(field_name)

        instance = ModelClass.objects.create(**attrs)

        # Save many-to-many relationships after the instance is created.
        if many_to_many:
            for field_name, value in many_to_many.items():
                setattr(instance, field_name, value)

        return instance

    def update(self, obj, attrs):
        """Set every item of `attrs` on `obj` and save it."""
        for attr, value in attrs.items():
            setattr(obj, attr, value)
        obj.save()

    def _get_base_fields(self):
        """
        Build the ordered mapping of field name -> serializer field.

        Raises `ImproperlyConfigured` when a declared field is missing
        from `Meta.fields`, or when a name in `Meta.fields` matches
        neither a declared field, the URL field, a model field, a
        relationship, nor an attribute of the model.
        """
        declared_fields = copy.deepcopy(self._declared_fields)

        ret = SortedDict()
        model = getattr(self.Meta, 'model')
        fields = getattr(self.Meta, 'fields', None)
        depth = getattr(self.Meta, 'depth', 0)
        extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})

        # Retrieve metadata about fields & relationships on the model class.
        info = model_meta.get_field_info(model)

        # Use the default set of fields if none is supplied explicitly.
        if fields is None:
            fields = self._get_default_field_names(declared_fields, info)

        # Check that any fields declared on the class are also explicitly
        # included in `Meta.fields`. This is done once, up front: the
        # result does not depend on the field being processed, and a
        # check inside the loop is never reached when every listed name
        # is a declared field.
        missing_fields = set(declared_fields.keys()) - set(fields)
        if missing_fields:
            # Sorted so that the reported field is deterministic.
            missing_field = sorted(missing_fields)[0]
            raise ImproperlyConfigured(
                'Field `%s` has been declared on serializer `%s`, but '
                'is missing from `Meta.fields`.' %
                (missing_field, self.__class__.__name__)
            )

        for field_name in fields:
            if field_name in declared_fields:
                # Field is explicitly declared on the class, use that.
                ret[field_name] = declared_fields[field_name]
                continue

            elif field_name == api_settings.URL_FIELD_NAME:
                # Create the URL field.
                field_cls = HyperlinkedIdentityField
                kwargs = get_url_kwargs(model)

            elif field_name in info.fields_and_pk:
                # Create regular model fields.
                model_field = info.fields_and_pk[field_name]
                field_cls = lookup_class(self._field_mapping, model_field)
                kwargs = get_field_kwargs(field_name, model_field)
                if 'choices' in kwargs:
                    # Fields with choices get coerced into `ChoiceField`
                    # instead of using their regular typed field.
                    field_cls = ChoiceField
                if not issubclass(field_cls, ModelField):
                    # `model_field` is only valid for the fallback case of
                    # `ModelField`, which is used when no other typed field
                    # matched to the model field.
                    kwargs.pop('model_field', None)

            elif field_name in info.relations:
                # Create forward and reverse relationships.
                relation_info = info.relations[field_name]
                if depth:
                    field_cls = self._get_nested_class(depth, relation_info)
                    kwargs = get_nested_relation_kwargs(relation_info)
                else:
                    field_cls = self._related_class
                    kwargs = get_relation_kwargs(field_name, relation_info)
                    # `view_name` is only valid for hyperlinked relationships.
                    if not issubclass(field_cls, HyperlinkedRelatedField):
                        kwargs.pop('view_name', None)

            elif hasattr(model, field_name):
                # Create a read only field for model methods and properties.
                field_cls = ReadOnlyField
                kwargs = {}

            else:
                # `model` is itself a class, so report `model.__name__`;
                # `model.__class__.__name__` would name its metaclass.
                raise ImproperlyConfigured(
                    'Field name `%s` is not valid for model `%s`.' %
                    (field_name, model.__name__)
                )

            # Populate any kwargs defined in `Meta.extra_kwargs`
            kwargs.update(extra_kwargs.get(field_name, {}))

            # Create the serializer field.
            ret[field_name] = field_cls(**kwargs)

        return ret

    def _get_default_field_names(self, declared_fields, model_info):
        """
        Return the field names used when `Meta.fields` is not set: the
        primary key, then declared fields, model fields and forward
        relationships.
        """
        return (
            [model_info.pk.name] +
            list(declared_fields.keys()) +
            list(model_info.fields.keys()) +
            list(model_info.forward_relations.keys())
        )

    def _get_nested_class(self, nested_depth, relation_info):
        """
        Return a `ModelSerializer` subclass for the related model of
        `relation_info`, with its `Meta.depth` set to `nested_depth`.
        """
        class NestedSerializer(ModelSerializer):
            class Meta:
                model = relation_info.related
                depth = nested_depth
        return NestedSerializer
class HyperlinkedModelSerializer(ModelSerializer):
    """
    A `ModelSerializer` variant that uses `HyperlinkedRelatedField` for
    relationships and leads its default fields with the URL field
    instead of the primary key.
    """
    _related_class = HyperlinkedRelatedField

    def _get_default_field_names(self, declared_fields, model_info):
        """
        Return the default field names: the URL field first, followed by
        declared fields, model fields and forward relationships.
        """
        names = [api_settings.URL_FIELD_NAME]
        names.extend(declared_fields.keys())
        names.extend(model_info.fields.keys())
        names.extend(model_info.forward_relations.keys())
        return names

    def _get_nested_class(self, nested_depth, relation_info):
        """
        Return a `HyperlinkedModelSerializer` subclass for the related
        model of `relation_info`, with `Meta.depth` set to `nested_depth`.
        """
        related_model = relation_info.related

        class NestedSerializer(HyperlinkedModelSerializer):
            class Meta:
                model = related_model
                depth = nested_depth

        return NestedSerializer