mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-30 05:24:00 +03:00
491 lines
18 KiB
Python
491 lines
18 KiB
Python
"""
|
|
Serializers and ModelSerializers are similar to Forms and ModelForms.
|
|
Unlike forms, they are not constrained to dealing with HTML output, and
|
|
form encoded input.
|
|
|
|
Serialization in REST framework is a two-phase process:
|
|
|
|
1. Serializers marshal between complex types like model instances, and
|
|
python primitives.
|
|
2. The process of marshalling between python primitives and request and
|
|
response content is handled by parsers and renderers.
|
|
"""
|
|
from django.core.exceptions import ImproperlyConfigured, ValidationError
|
|
from django.db import models
|
|
from django.utils import six
|
|
from django.utils.datastructures import SortedDict
|
|
from collections import namedtuple
|
|
from rest_framework.fields import empty, set_value, Field, SkipField
|
|
from rest_framework.settings import api_settings
|
|
from rest_framework.utils import html, model_meta, representation
|
|
from rest_framework.utils.field_mapping import (
|
|
get_url_kwargs, get_field_kwargs,
|
|
get_relation_kwargs, get_nested_relation_kwargs,
|
|
lookup_class
|
|
)
|
|
import copy
|
|
import inspect
|
|
|
|
# Note: We do the following so that users of the framework can use this style:
|
|
#
|
|
# example_field = serializers.CharField(...)
|
|
#
|
|
# This helps keep the separation between model fields, form fields, and
|
|
# serializer fields more explicit.
|
|
|
|
from rest_framework.relations import * # NOQA
|
|
from rest_framework.fields import * # NOQA
|
|
|
|
|
|
# Lightweight record yielded when iterating over a serializer: the field
# object, its current value, and any error recorded against it.
FieldResult = namedtuple('FieldResult', 'field value error')
|
|
|
|
|
|
class BaseSerializer(Field):
    """
    Minimal serializer base class, intended as a starting point for
    custom serializer implementations.

    Subclasses supply `to_internal_value()`, `to_representation()`,
    `update()` and `create()`. This class layers the `save()` and
    `is_valid()` workflow on top of them, together with the lazily
    evaluated `data`, `errors` and `validated_data` properties.
    """

    def __init__(self, instance=None, data=None, **kwargs):
        # Any remaining keyword arguments are forwarded to `Field.__init__`.
        super(BaseSerializer, self).__init__(**kwargs)
        self._initial_data = data
        self.instance = instance

    def to_internal_value(self, data):
        """Convert primitive input into validated values (abstract)."""
        raise NotImplementedError('`to_internal_value()` must be implemented.')

    def to_representation(self, instance):
        """Convert an instance into primitive output (abstract)."""
        raise NotImplementedError('`to_representation()` must be implemented.')

    def update(self, instance, attrs):
        """Apply `attrs` to an existing `instance` (abstract)."""
        raise NotImplementedError('`update()` must be implemented.')

    def create(self, attrs):
        """Build and return a new instance from `attrs` (abstract)."""
        raise NotImplementedError('`create()` must be implemented.')

    def save(self, extras=None):
        """
        Persist the validated data and return `self.instance`.

        When `extras` is given, its items are layered over the validated
        data, winning on key clashes. An existing instance is passed to
        `update()`; otherwise the result of `create()` is stored on
        `self.instance`.
        """
        attrs = self.validated_data
        if extras is not None:
            # Merge into a fresh dict so `validated_data` is left untouched.
            merged = dict(attrs.items())
            merged.update(extras.items())
            attrs = merged

        if self.instance is None:
            self.instance = self.create(attrs)
        else:
            self.update(self.instance, attrs)

        return self.instance

    def is_valid(self, raise_exception=False):
        """
        Run validation once and report whether the initial data is valid.

        The outcome is cached on `_validated_data` / `_errors`, so later
        calls do not re-run `to_internal_value()`. With
        `raise_exception=True` a `ValidationError` carrying the errors is
        raised instead of returning False.
        """
        if not hasattr(self, '_validated_data'):
            try:
                validated = self.to_internal_value(self._initial_data)
            except ValidationError as exc:
                self._validated_data = {}
                self._errors = exc.message_dict
            else:
                self._validated_data = validated
                self._errors = {}

        if raise_exception and self._errors:
            raise ValidationError(self._errors)

        return not self._errors

    @property
    def data(self):
        """
        Primitive representation of the serializer, computed on first access.

        Taken from the instance when there is one, otherwise from the
        initial data (one entry per field), otherwise from `get_initial()`.
        """
        if hasattr(self, '_data'):
            return self._data

        if self.instance is not None:
            self._data = self.to_representation(self.instance)
        elif self._initial_data is not None:
            initial = self._initial_data
            self._data = dict(
                (name, field.get_value(initial))
                for name, field in self.fields.items()
            )
        else:
            self._data = self.get_initial()
        return self._data

    @property
    def errors(self):
        """Validation errors; only available after `.is_valid()` has run."""
        if hasattr(self, '_errors'):
            return self._errors
        raise AssertionError(
            'You must call `.is_valid()` before accessing `.errors`.'
        )

    @property
    def validated_data(self):
        """Validated values; only available after `.is_valid()` has run."""
        if hasattr(self, '_validated_data'):
            return self._validated_data
        raise AssertionError(
            'You must call `.is_valid()` before accessing `.validated_data`.'
        )
|
|
|
|
|
|
class SerializerMetaclass(type):
    """
    Metaclass that records a class's serializer fields in an ordered
    mapping stored on the class as `_declared_fields`.

    Every `Field` instance found among the class attributes is moved into
    that mapping, placed after the `_declared_fields` inherited from the
    base classes.
    """

    @classmethod
    def _get_declared_fields(cls, bases, attrs):
        """
        Pop the `Field` instances out of `attrs` and return them, preceded
        by the fields of `bases`, as a `SortedDict`.
        """
        own_fields = []
        # Iterate over a snapshot, since matching entries are removed
        # from `attrs` as we go.
        for attr_name, attr_value in list(attrs.items()):
            if isinstance(attr_value, Field):
                own_fields.append((attr_name, attrs.pop(attr_name)))
        own_fields.sort(key=lambda pair: pair[1]._creation_counter)

        # Fields inherited from other serializers come first, in base-class
        # order, so that this class's own declarations end up last.
        inherited_fields = []
        for base in bases:
            if hasattr(base, '_declared_fields'):
                inherited_fields.extend(base._declared_fields.items())

        return SortedDict(inherited_fields + own_fields)

    def __new__(cls, name, bases, attrs):
        declared = cls._get_declared_fields(bases, attrs)
        attrs['_declared_fields'] = declared
        return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
|
|
|
|
|
|
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
    """
    Serializer whose fields are declared as class attributes.

    The metaclass gathers the declared fields into `_declared_fields`;
    each instance then works on its own deep copy, exposed as
    `self.fields`.
    """

    def __init__(self, *args, **kwargs):
        self.context = kwargs.pop('context', {})
        # `partial` and `many` are dropped before the remaining arguments
        # are handed to the parent initializer.
        for discarded in ('partial', 'many'):
            kwargs.pop(discarded, None)

        super(Serializer, self).__init__(*args, **kwargs)

        # Work on per-instance clones of the fields, so that modifying
        # the fields of one serializer instance leaves the class-level
        # declarations untouched.
        self.fields = self._get_base_fields()

        # Bind every child field to this serializer, which acts as both
        # its parent and its root.
        for name, child in self.fields.items():
            child.bind(name, self, self)

    def __new__(cls, *args, **kwargs):
        # With `many=True` a `ListSerializer` wrapping a fresh instance of
        # this class is returned in place of the serializer itself.
        many = kwargs.pop('many', False)
        if not many:
            return super(Serializer, cls).__new__(cls, *args, **kwargs)
        kwargs['child'] = cls()
        return ListSerializer(*args, **kwargs)

    def _get_base_fields(self):
        """Return a deep copy of the fields declared on the class."""
        declared = self._declared_fields
        return copy.deepcopy(declared)

    def bind(self, field_name, parent, root):
        # When the serializer is itself used as a field, binding it must
        # also re-bind its children against the new root.
        super(Serializer, self).bind(field_name, parent, root)
        for child_name, child in self.fields.items():
            child.bind(child_name, self, root)

    def get_initial(self):
        """Return a dict mapping each field's name to its initial value."""
        return dict(
            (child.field_name, child.get_initial())
            for child in self.fields.values()
        )

    def get_value(self, dictionary):
        # HTML form input is parsed into a nested dict keyed on this
        # serializer's field name; any other input is a plain lookup.
        if not html.is_html_input(dictionary):
            return dictionary.get(self.field_name, empty)
        return html.parse_html_dict(dictionary, prefix=self.field_name)

    def to_internal_value(self, data):
        """
        Dict of native values <- Dict of primitive datatypes.

        Runs each non-read-only field's validation, followed by the
        optional `validate_<field_name>` method, then the serializer-wide
        `validate()`. Raises `ValidationError` keyed by field name, or by
        `NON_FIELD_ERRORS_KEY` for errors not tied to a field.
        """
        if not isinstance(data, dict):
            raise ValidationError({
                api_settings.NON_FIELD_ERRORS_KEY: ['Invalid data']
            })

        validated = {}
        field_errors = {}
        writable = [
            candidate for candidate in self.fields.values()
            if not candidate.read_only
        ]

        for field in writable:
            hook = getattr(self, 'validate_' + field.field_name, None)
            raw_value = field.get_value(data)
            try:
                value = field.run_validation(raw_value)
                if hook is not None:
                    value = hook(value)
            except ValidationError as exc:
                field_errors[field.field_name] = exc.messages
                continue
            except SkipField:
                # The field opted out; it contributes nothing to the result.
                continue
            set_value(validated, field.source_attrs, value)

        if field_errors:
            raise ValidationError(field_errors)

        try:
            return self.validate(validated)
        except ValidationError as exc:
            raise ValidationError({
                api_settings.NON_FIELD_ERRORS_KEY: exc.messages
            })

    def to_representation(self, instance):
        """
        Object instance -> Dict of primitive datatypes.

        Write-only fields are left out of the result.
        """
        output = SortedDict()
        readable = [
            candidate for candidate in self.fields.values()
            if not candidate.write_only
        ]

        for field in readable:
            attribute = field.get_attribute(instance)
            output[field.field_name] = field.to_representation(attribute)

        return output

    def validate(self, attrs):
        """Serializer-wide validation hook; returns `attrs` unchanged."""
        return attrs

    def __iter__(self):
        """
        Yield a `FieldResult(field, value, error)` for every field.

        Errors are only consulted when validation has already run.
        """
        if hasattr(self, '_errors'):
            errors = self.errors
        else:
            errors = {}

        for field in self.fields.values():
            name = field.field_name
            current = self.data
            value = current.get(name) if current else None
            yield FieldResult(field, value, errors.get(name))

    def __repr__(self):
        return representation.serializer_repr(self, indent=1)
|
|
|
|
|
|
class ListSerializer(BaseSerializer):
    """
    Serializer that applies a single `child` serializer or field to every
    item of a list.

    `child` may be set as a class attribute or passed as a keyword
    argument; it must be an instance, not a class.
    """
    child = None
    initial = []

    def __init__(self, *args, **kwargs):
        self.child = kwargs.pop('child', copy.deepcopy(self.child))
        assert self.child is not None, '`child` is a required argument.'
        assert not inspect.isclass(self.child), '`child` has not been instantiated.'
        self.context = kwargs.pop('context', {})
        # `partial` is dropped before the remaining arguments are passed on.
        kwargs.pop('partial', None)

        super(ListSerializer, self).__init__(*args, **kwargs)
        self.child.bind('', self, self)

    def bind(self, field_name, parent, root):
        # If the list is used as a field then it needs to provide
        # the current context to the child serializer.
        super(ListSerializer, self).bind(field_name, parent, root)
        self.child.bind(field_name, self, root)

    def get_value(self, dictionary):
        # We override the default field access in order to support
        # lists in HTML forms.
        # The helper is reached through the `html` module, the same way
        # `to_internal_value()` does below; the bare name `is_html_input`
        # is not imported by name in this module.
        if html.is_html_input(dictionary):
            return html.parse_html_list(dictionary, prefix=self.field_name)
        return dictionary.get(self.field_name, empty)

    def to_internal_value(self, data):
        """
        List of dicts of native values <- List of dicts of primitive datatypes.
        """
        if html.is_html_input(data):
            data = html.parse_html_list(data)

        return [self.child.run_validation(item) for item in data]

    def to_representation(self, data):
        """
        List of object instances -> List of dicts of primitive datatypes.
        """
        return [self.child.to_representation(item) for item in data]

    def create(self, attrs_list):
        """Create one instance per item by delegating to `child.create()`."""
        return [self.child.create(attrs) for attrs in attrs_list]

    def save(self, extras=None):
        """
        Persist the validated items and return `self.instance`.

        An existing instance is passed to `update()`; only when there is
        no instance is `create()` called and its result stored. When
        `extras` is given, its items are merged over every validated item
        (which must then be mappings), mirroring `BaseSerializer.save()`.
        """
        validated_data = self.validated_data
        if extras is not None:
            validated_data = [
                dict(list(attrs.items()) + list(extras.items()))
                for attrs in validated_data
            ]

        if self.instance is not None:
            self.update(self.instance, validated_data)
        else:
            # Previously this ran unconditionally, so an update was always
            # followed by a create that replaced `self.instance`.
            self.instance = self.create(validated_data)
        return self.instance

    def __repr__(self):
        return representation.list_repr(self, indent=1)
|
|
|
|
|
|
class ModelSerializer(Serializer):
    """
    Serializer whose fields are generated from the Django model named in
    `Meta.model`.

    Recognised `Meta` options: `model` (required), `fields`, `depth` and
    `extra_kwargs`.
    """

    # Model field class -> serializer field class, resolved through
    # `lookup_class()` when building fields for the model.
    _field_mapping = {
        models.AutoField: IntegerField,
        models.BigIntegerField: IntegerField,
        models.BooleanField: BooleanField,
        models.CharField: CharField,
        models.CommaSeparatedIntegerField: CharField,
        models.DateField: DateField,
        models.DateTimeField: DateTimeField,
        models.DecimalField: DecimalField,
        models.EmailField: EmailField,
        models.Field: ModelField,
        models.FileField: FileField,
        models.FloatField: FloatField,
        models.ImageField: ImageField,
        models.IntegerField: IntegerField,
        models.NullBooleanField: BooleanField,
        models.PositiveIntegerField: IntegerField,
        models.PositiveSmallIntegerField: IntegerField,
        models.SlugField: SlugField,
        models.SmallIntegerField: IntegerField,
        models.TextField: CharField,
        models.TimeField: TimeField,
        models.URLField: URLField,
    }
    # Field class used for relationships when `depth` is zero.
    _related_class = PrimaryKeyRelatedField

    def create(self, attrs):
        """
        Create and return a `Meta.model` instance from `attrs`.

        To-many relationships are set after the instance has been created.
        """
        ModelClass = self.Meta.model

        # Work on a shallow copy: the to-many entries are popped below,
        # and the mapping passed in may be the serializer's own
        # `validated_data`.
        attrs = dict(attrs)

        # Remove many-to-many relationships from attrs.
        # They are not valid arguments to the default `.create()` method,
        # as they require that the instance has already been saved.
        info = model_meta.get_field_info(ModelClass)
        many_to_many = {}
        for field_name, relation_info in info.relations.items():
            if relation_info.to_many and (field_name in attrs):
                many_to_many[field_name] = attrs.pop(field_name)

        instance = ModelClass.objects.create(**attrs)

        # Save many-to-many relationships after the instance is created.
        for field_name, value in many_to_many.items():
            setattr(instance, field_name, value)

        return instance

    def update(self, obj, attrs):
        """Set every item of `attrs` on `obj`, then save it."""
        for attr, value in attrs.items():
            setattr(obj, attr, value)
        obj.save()

    def _get_base_fields(self):
        """
        Build the ordered mapping of field name -> serializer field.

        Raises `ImproperlyConfigured` when a field is declared on the
        class but left out of `Meta.fields`, or when a name in
        `Meta.fields` matches nothing on the model.
        """
        declared_fields = copy.deepcopy(self._declared_fields)

        ret = SortedDict()
        model = getattr(self.Meta, 'model')
        fields = getattr(self.Meta, 'fields', None)
        depth = getattr(self.Meta, 'depth', 0)
        extra_kwargs = getattr(self.Meta, 'extra_kwargs', {})

        # Retrieve metadata about fields & relationships on the model class.
        info = model_meta.get_field_info(model)

        # Use the default set of fields if none is supplied explicitly.
        if fields is None:
            fields = self._get_default_field_names(declared_fields, info)

        # Check that any fields declared on the class are also explicitly
        # included in `Meta.fields`. This does not depend on the field
        # being processed, so it runs once, up front; that also covers the
        # cases where `fields` is empty or lists only declared fields.
        missing_fields = set(declared_fields.keys()) - set(fields)
        if missing_fields:
            # Sorted so the reported field is deterministic.
            missing_field = sorted(missing_fields)[0]
            raise ImproperlyConfigured(
                'Field `%s` has been declared on serializer `%s`, but '
                'is missing from `Meta.fields`.' %
                (missing_field, self.__class__.__name__)
            )

        for field_name in fields:
            if field_name in declared_fields:
                # Field is explicitly declared on the class, use that.
                ret[field_name] = declared_fields[field_name]
                continue

            elif field_name == api_settings.URL_FIELD_NAME:
                # Create the URL field.
                field_cls = HyperlinkedIdentityField
                kwargs = get_url_kwargs(model)

            elif field_name in info.fields_and_pk:
                # Create regular model fields.
                model_field = info.fields_and_pk[field_name]
                field_cls = lookup_class(self._field_mapping, model_field)
                kwargs = get_field_kwargs(field_name, model_field)
                if 'choices' in kwargs:
                    # Fields with choices get coerced into `ChoiceField`
                    # instead of using their regular typed field.
                    field_cls = ChoiceField
                if not issubclass(field_cls, ModelField):
                    # `model_field` is only valid for the fallback case of
                    # `ModelField`, which is used when no other typed field
                    # matched to the model field.
                    kwargs.pop('model_field', None)

            elif field_name in info.relations:
                # Create forward and reverse relationships.
                relation_info = info.relations[field_name]
                if depth:
                    field_cls = self._get_nested_class(depth, relation_info)
                    kwargs = get_nested_relation_kwargs(relation_info)
                else:
                    field_cls = self._related_class
                    kwargs = get_relation_kwargs(field_name, relation_info)
                    # `view_name` is only valid for hyperlinked relationships.
                    if not issubclass(field_cls, HyperlinkedRelatedField):
                        kwargs.pop('view_name', None)

            elif hasattr(model, field_name):
                # Create a read only field for model methods and properties.
                field_cls = ReadOnlyField
                kwargs = {}

            else:
                # `model` is the model class itself, so `__name__` gives
                # the model's name; `__class__.__name__` would name its
                # metaclass.
                raise ImproperlyConfigured(
                    'Field name `%s` is not valid for model `%s`.' %
                    (field_name, model.__name__)
                )

            # Populate any kwargs defined in `Meta.extra_kwargs`
            kwargs.update(extra_kwargs.get(field_name, {}))

            # Create the serializer field.
            ret[field_name] = field_cls(**kwargs)

        return ret

    def _get_default_field_names(self, declared_fields, model_info):
        """
        Field names used when `Meta.fields` is not set: the primary key,
        the declared fields, the model fields and the forward relations.
        """
        return (
            [model_info.pk.name] +
            list(declared_fields.keys()) +
            list(model_info.fields.keys()) +
            list(model_info.forward_relations.keys())
        )

    def _get_nested_class(self, nested_depth, relation_info):
        """
        Return a `ModelSerializer` subclass for the related model.

        The nested serializer gets one level of depth less than this one,
        so that the nesting stops once the depth reaches zero.
        """
        class NestedSerializer(ModelSerializer):
            class Meta:
                model = relation_info.related
                depth = nested_depth - 1
        return NestedSerializer
|
|
|
|
|
|
class HyperlinkedModelSerializer(ModelSerializer):
    """
    `ModelSerializer` variant that uses `HyperlinkedRelatedField` for
    relationships and puts the URL field first in its default field names,
    where `ModelSerializer` puts the primary key.
    """
    _related_class = HyperlinkedRelatedField

    def _get_default_field_names(self, declared_fields, model_info):
        """
        Field names used when `Meta.fields` is not set: the URL field,
        the declared fields, the model fields and the forward relations.
        """
        return (
            [api_settings.URL_FIELD_NAME] +
            list(declared_fields.keys()) +
            list(model_info.fields.keys()) +
            list(model_info.forward_relations.keys())
        )

    def _get_nested_class(self, nested_depth, relation_info):
        """
        Return a `HyperlinkedModelSerializer` subclass for the related model.

        The nested serializer gets one level of depth less than this one,
        so that the nesting stops once the depth reaches zero.
        """
        class NestedSerializer(HyperlinkedModelSerializer):
            class Meta:
                model = relation_info.related
                depth = nested_depth - 1
        return NestedSerializer
|