django-rest-framework/rest_framework/serializers.py

641 lines
22 KiB
Python
Raw Normal View History

2013-04-25 15:47:34 +04:00
"""
Serializers and ModelSerializers are similar to Forms and ModelForms.
Unlike forms, they are not constrained to dealing with HTML output and
form-encoded input.
Serialization in REST framework is a two-phase process:
1. Serializers marshal between complex types like model instances, and
Python primitives.
2. The process of marshalling between Python primitives and request and
response content is handled by parsers and renderers.
"""
from django.db import models
from django.utils import six
2014-08-29 19:46:26 +04:00
from collections import namedtuple, OrderedDict
from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError
from rest_framework.settings import api_settings
2014-08-29 19:46:26 +04:00
from rest_framework.utils import html
import copy
import inspect
2012-11-05 14:56:30 +04:00
# Note: We do the following so that users of the framework can use this style:
#
#     example_field = serializers.CharField(...)
#
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
2014-05-17 02:05:33 +04:00
from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA
2014-08-29 19:46:26 +04:00
# Record yielded by `Serializer.__iter__`: a field together with its current
# value and its validation error (if any).
FieldResult = namedtuple('FieldResult', 'field value error')
2014-08-29 19:46:26 +04:00
class BaseSerializer(Field):
    """
    Common base for all serializers.

    Stores the optional `instance` to serialize and the raw `data` to
    deserialize, and drives the validate / save / data-access workflow on
    top of four hooks that subclasses implement: `to_native`,
    `to_primative`, `create` and `update`.
    """

    def __init__(self, instance=None, data=None, **kwargs):
        super(BaseSerializer, self).__init__(**kwargs)
        self.instance = instance
        # Raw input; it is only converted when `.is_valid()` is called.
        self._initial_data = data

    # -- Hooks for subclasses ------------------------------------------------

    def to_native(self, data):
        """Primitive input data -> validated native data."""
        raise NotImplementedError()

    def to_primative(self, instance):
        """Object instance -> primitive representation."""
        raise NotImplementedError()

    def update(self, instance):
        """Apply the validated data to the existing `instance`."""
        raise NotImplementedError()

    def create(self):
        """Build and return a new instance from the validated data."""
        raise NotImplementedError()

    # -- Workflow ------------------------------------------------------------

    def save(self, extras=None):
        """
        Persist the validated data and return `self.instance`.

        `extras`, when given, is merged into the validated data first.
        Without an instance, `.create()` is called and its result stored as
        `self.instance`; otherwise the existing instance goes to `.update()`.
        """
        if extras is not None:
            self._validated_data.update(extras)

        if self.instance is None:
            self.instance = self.create()
        else:
            self.update(self.instance)
        return self.instance

    def is_valid(self):
        """
        Convert the initial data with `to_native` and record the outcome.

        Returns True on success. On `ValidationError`, the validated data
        becomes `{}`, the exception's first argument becomes `.errors`, and
        False is returned.
        """
        try:
            native = self.to_native(self._initial_data)
        except ValidationError as exc:
            self._validated_data = {}
            self._errors = exc.args[0]
            return False
        self._validated_data = native
        self._errors = {}
        return True

    @property
    def data(self):
        """
        Primitive representation, computed on first access and then cached.

        Built from the instance if there is one. Otherwise it is built from
        each field's value in the initial data, which relies on subclasses
        providing `self.fields`. Failing both, `get_initial()` is used.
        """
        if hasattr(self, '_data'):
            return self._data

        if self.instance is not None:
            result = self.to_primative(self.instance)
        elif self._initial_data is not None:
            source = self._initial_data
            result = dict(
                (name, fld.get_value(source))
                for name, fld in self.fields.items()
            )
        else:
            result = self.get_initial()

        self._data = result
        return result

    @property
    def errors(self):
        """Errors recorded by the last `.is_valid()` call."""
        if hasattr(self, '_errors'):
            return self._errors
        raise AssertionError('You must call `.is_valid()` before accessing `.errors`.')

    @property
    def validated_data(self):
        """Native data produced by the last successful `.is_valid()` call."""
        if hasattr(self, '_validated_data'):
            return self._validated_data
        raise AssertionError('You must call `.is_valid()` before accessing `.validated_data`.')
2014-08-29 19:46:26 +04:00
class SerializerMetaclass(type):
    """
    This metaclass sets a dictionary named `base_fields` on the class.

    Any fields declared as attributes on the class, or present in the
    `base_fields` of any of its superclasses, are included in the
    `base_fields` dictionary.
    """

    @classmethod
    def _get_fields(cls, bases, attrs):
        # Remove the Field attributes from the class namespace and order
        # them by the sequence in which they were instantiated.
        declared = []
        for attr_name, attr_value in list(attrs.items()):
            if isinstance(attr_value, Field):
                declared.append((attr_name, attrs.pop(attr_name)))
        declared.sort(key=lambda pair: pair[1]._creation_counter)

        # Prepend the fields of any base serializers. The bases are visited
        # in *reverse* so that the leftmost base's fields end up first.
        for base in reversed(bases):
            if hasattr(base, 'base_fields'):
                declared = list(base.base_fields.items()) + declared

        return OrderedDict(declared)

    def __new__(cls, name, bases, attrs):
        attrs['base_fields'] = cls._get_fields(bases, attrs)
        return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
2014-08-29 19:46:26 +04:00
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
    """
    A serializer built from a declared set of fields.

    `SerializerMetaclass` collects the fields declared on the class into
    `base_fields`, and every instance works on its own deep copy of them.
    Passing `many=True` returns an instance of a ListSerializer subclass
    whose `child` is an instance of this class.
    """

    def __new__(cls, *args, **kwargs):
        if not kwargs.pop('many', False):
            return super(Serializer, cls).__new__(cls)

        class DynamicListSerializer(ListSerializer):
            child = cls()

        return DynamicListSerializer(*args, **kwargs)

    def __init__(self, *args, **kwargs):
        self.context = kwargs.pop('context', {})
        # Consume these so they are not passed on to the base class.
        for discarded in ('partial', 'many'):
            kwargs.pop(discarded, None)

        super(Serializer, self).__init__(*args, **kwargs)

        # Each instance gets a clone of the field instances, so fields can be
        # modified dynamically without affecting other serializers.
        self.fields = self.get_fields()

        # Bind every child field; this serializer is both parent and root.
        for name, child in self.fields.items():
            child.bind(name, self, self)

    def get_fields(self):
        """Return a deep copy of the declared fields."""
        return copy.deepcopy(self.base_fields)

    def bind(self, field_name, parent, root):
        # When used as a nested field, binding this serializer must also
        # rebind its child fields against the new root.
        super(Serializer, self).bind(field_name, parent, root)
        for name, child in self.fields.items():
            child.bind(name, self, root)

    def get_initial(self):
        """Map each field name to that field's initial value."""
        return dict(
            (fld.field_name, fld.get_initial())
            for fld in self.fields.values()
        )

    def get_value(self, dictionary):
        # Overridden so that nested HTML form input is parsed using this
        # field's name as the key prefix.
        if not html.is_html_input(dictionary):
            return dictionary.get(self.field_name, empty)
        return html.parse_html_dict(dictionary, prefix=self.field_name)

    def to_native(self, data):
        """
        Dict of native values <- Dict of primitive datatypes.

        Each non-read-only field extracts its value from `data` and validates
        it. Fields raising SkipField are left out. For fields raising
        ValidationError, the string form of the error is collected, and all
        such errors are raised together as one ValidationError keyed by field
        name. On success the result is passed through `validate()`.
        """
        native = {}
        failures = {}
        writable = [fld for fld in self.fields.values() if not fld.read_only]

        for fld in writable:
            raw = fld.get_value(data)
            try:
                cleaned = fld.validate(raw)
            except ValidationError as exc:
                failures[fld.field_name] = str(exc)
                continue
            except SkipField:
                continue
            set_value(native, fld.source_attrs, cleaned)

        if failures:
            raise ValidationError(failures)
        return self.validate(native)

    def to_primative(self, instance):
        """
        Object instance -> Dict of primitive datatypes.

        Write-only fields are omitted; field order is preserved.
        """
        readable = [fld for fld in self.fields.values() if not fld.write_only]
        output = OrderedDict()
        for fld in readable:
            output[fld.field_name] = fld.to_primative(fld.get_attribute(instance))
        return output

    def validate(self, attrs):
        """Object-level validation hook; returns `attrs` unchanged by default."""
        return attrs

    def __iter__(self):
        # Yield every field paired with its current value and its error.
        errors = self.errors if hasattr(self, '_errors') else {}
        for fld in self.fields.values():
            name = fld.field_name
            value = self.data.get(name) if self.data else None
            yield FieldResult(fld, value, errors.get(name))
2014-08-29 19:46:26 +04:00
class ListSerializer(BaseSerializer):
    """
    Serializes and deserializes a list of items using a single `child`.

    `child` may be passed as a keyword argument. Otherwise the class-level
    `child` is used (as in the class built by `Serializer(many=True)`),
    deep-copied so that instances do not share it.
    """
    child = None
    initial = []

    def __init__(self, *args, **kwargs):
        self.child = kwargs.pop('child', copy.deepcopy(self.child))
        assert self.child is not None, '`child` is a required argument.'

        self.context = kwargs.pop('context', {})

        # Accepted for API compatibility but not used here.
        kwargs.pop('partial', None)

        super(ListSerializer, self).__init__(*args, **kwargs)
        self.child.bind('', self, self)

    def bind(self, field_name, parent, root):
        # If the list is used as a field then it needs to provide
        # the current context to the child serializer.
        super(ListSerializer, self).bind(field_name, parent, root)
        self.child.bind(field_name, self, root)

    def get_value(self, dictionary):
        # We override the default field access in order to support
        # lists in HTML forms. `is_html_input` lives in the `html` helper
        # module (as used by `Serializer.get_value`); it is not a bare name
        # in this module.
        if html.is_html_input(dictionary):
            return html.parse_html_list(dictionary, prefix=self.field_name)
        return dictionary.get(self.field_name, empty)

    def to_native(self, data):
        """
        List of dicts of native values <- List of dicts of primitive datatypes.
        """
        if html.is_html_input(data):
            data = html.parse_html_list(data)

        # NOTE(review): when the child is a `Serializer`, `validate` is its
        # object-level hook, which returns its input unchanged by default, so
        # items are not run through the child's `to_native`. Confirm intent.
        return [self.child.validate(item) for item in data]

    def to_primative(self, data):
        """
        List of object instances -> List of dicts of primitive datatypes.
        """
        return [self.child.to_primative(item) for item in data]

    def create(self, attrs_list):
        """Create one object per item by delegating to the child."""
        # NOTE(review): `ModelSerializer.create` takes no arguments, so a
        # model serializer child will not accept `attrs` here. Verify.
        return [self.child.create(attrs) for attrs in attrs_list]

    def update(self, instance, attrs_list):
        """
        Update an existing list of objects from `attrs_list`.

        Not implemented. This override exists so that `save()` fails with
        NotImplementedError instead of a TypeError from the single-argument
        `BaseSerializer.update`.
        """
        raise NotImplementedError()

    def save(self):
        """
        Persist the validated data and return `self.instance`.

        An existing instance is passed to `update()`. Otherwise the result
        of `create()` becomes `self.instance`. Only one of the two runs.
        """
        if self.instance is not None:
            self.update(self.instance, self.validated_data)
        else:
            self.instance = self.create(self.validated_data)
        return self.instance
2013-03-09 14:21:53 +04:00
2014-08-29 19:46:26 +04:00
def _resolve_model(obj):
    """
    Resolve the supplied `obj` to a Django model class.

    `obj` must be either a Django model class itself or a string of the
    form 'appname.ModelName'. String references are useful in situations
    like GH #1225, where Django may not yet have resolved a string-based
    reference to a model in another model's foreign key definition.

    Raises ValueError for anything else.
    """
    if isinstance(obj, six.string_types):
        parts = obj.split('.')
        if len(parts) == 2:
            app_name, model_name = parts
            return models.get_model(app_name, model_name)

    if inspect.isclass(obj) and issubclass(obj, models.Model):
        return obj

    raise ValueError("{0} is not a Django model".format(obj))
2014-08-29 19:46:26 +04:00
class ModelSerializerOptions(object):
    """
    Options read from a ModelSerializer's inner `Meta` class.

    `model` is required; if it is missing, AttributeError is raised.
    `fields` defaults to an empty tuple and `depth` defaults to 0.
    """
    def __init__(self, meta):
        self.model = meta.model
        self.fields = getattr(meta, 'fields', ())
        self.depth = getattr(meta, 'depth', 0)
2012-10-04 16:28:14 +04:00
class ModelSerializer(Serializer):
    """
    A Serializer whose default fields are generated from a Django model.

    Configuration comes from the inner `Meta` class (see
    `ModelSerializerOptions`):

    - `model` is required.
    - `fields` optionally restricts and orders the fields.
    - A non-zero `depth` serializes relationships as nested serializers
      rather than flat related fields.

    Fields declared explicitly on the class take precedence over generated
    ones.
    """
    # Model field class -> serializer field class. `get_field` looks up the
    # model field's exact class and falls back to CharField when unmapped.
    field_mapping = {
        models.AutoField: IntegerField,
        # models.FloatField: FloatField,
        models.IntegerField: IntegerField,
        models.PositiveIntegerField: IntegerField,
        models.SmallIntegerField: IntegerField,
        models.PositiveSmallIntegerField: IntegerField,
        models.DateTimeField: DateTimeField,
        models.DateField: DateField,
        models.TimeField: TimeField,
        # models.DecimalField: DecimalField,
        models.EmailField: EmailField,
        models.CharField: CharField,
        models.URLField: URLField,
        # models.SlugField: SlugField,
        models.TextField: CharField,
        models.CommaSeparatedIntegerField: CharField,
        models.BooleanField: BooleanField,
        models.NullBooleanField: BooleanField,
        models.FileField: FileField,
        # models.ImageField: ImageField,
    }

    _options_class = ModelSerializerOptions

    def __init__(self, *args, **kwargs):
        # Parse the options first: Serializer.__init__ calls get_fields(),
        # which needs self.opts.
        self.opts = self._options_class(self.Meta)
        super(ModelSerializer, self).__init__(*args, **kwargs)

    def create(self):
        """Create and return a new model instance from the validated data."""
        ModelClass = self.opts.model
        return ModelClass.objects.create(**self.validated_data)

    def update(self, obj):
        """Set each validated attribute on `obj`, then call `obj.save()`."""
        for attr, value in self.validated_data.items():
            setattr(obj, attr, value)
        obj.save()

    def get_fields(self):
        """
        Return the fields for this serializer instance.

        Explicitly declared fields take precedence over the defaults
        generated from the model. If `Meta.fields` is set, only those
        fields are returned, in that order. A name in `Meta.fields` that
        matches no field raises KeyError.
        """
        # Get the explicitly declared fields.
        fields = copy.deepcopy(self.base_fields)

        # Add in the default fields.
        for key, val in self.get_default_fields().items():
            if key not in fields:
                fields[key] = val

        # If `fields` is set on the `Meta` class,
        # then use only those fields, and in that order.
        if self.opts.fields:
            fields = OrderedDict([
                (key, fields[key]) for key in self.opts.fields
            ])

        return fields

    def get_default_fields(self):
        """
        Return all the fields that should be serialized for the model.

        The fields are generated in this order:

        - The primary key, if `get_pk_field` returns a field. For
          multi-table inheritance children this is the parent's pk.
        - All serializable forward fields and many-to-many fields.
        - Reverse relationships, but only those whose accessor name is
          listed in `Meta.fields`.

        Relations become nested serializers when `Meta.depth` is non-zero
        and flat related fields otherwise. Many-to-many relations whose
        `through` model was not auto-created are made read-only.
        """
        cls = self.opts.model
        opts = cls._meta.concrete_model._meta
        ret = OrderedDict()
        nested = bool(self.opts.depth)

        # Deal with adding the primary key field
        pk_field = opts.pk
        while pk_field.rel and pk_field.rel.parent_link:
            # If model is a child via multitable inheritance, use parent's pk
            pk_field = pk_field.rel.to._meta.pk

        serializer_pk_field = self.get_pk_field(pk_field)
        if serializer_pk_field:
            ret[pk_field.name] = serializer_pk_field

        # Deal with forward relationships
        forward_rels = [field for field in opts.fields if field.serialize]
        forward_rels += [field for field in opts.many_to_many if field.serialize]

        for model_field in forward_rels:
            has_through_model = False

            if model_field.rel:
                to_many = isinstance(model_field,
                                     models.fields.related.ManyToManyField)
                related_model = _resolve_model(model_field.rel.to)

                if to_many and not model_field.rel.through._meta.auto_created:
                    has_through_model = True

                if nested:
                    field = self.get_nested_field(model_field, related_model, to_many)
                else:
                    field = self.get_related_field(model_field, related_model, to_many)
            else:
                field = self.get_field(model_field)

            if field:
                if has_through_model:
                    field.read_only = True
                ret[model_field.name] = field

        # Reverse relationships are only included if they are explicitly
        # present in the `fields` option on the serializer.
        if not self.opts.fields:
            return ret

        # Copy into a fresh list rather than extending the list returned by
        # the `_meta` API in place.
        reverse_rels = (
            list(opts.get_all_related_objects()) +
            list(opts.get_all_related_many_to_many_objects())
        )

        for relation in reverse_rels:
            accessor_name = relation.get_accessor_name()
            if accessor_name not in self.opts.fields:
                continue

            related_model = relation.model
            to_many = relation.field.rel.multiple
            is_m2m = isinstance(relation.field,
                                models.fields.related.ManyToManyField)
            has_through_model = (
                is_m2m and
                hasattr(relation.field.rel, 'through') and
                not relation.field.rel.through._meta.auto_created
            )

            if nested:
                field = self.get_nested_field(None, related_model, to_many)
            else:
                field = self.get_related_field(None, related_model, to_many)

            if field:
                if has_through_model:
                    field.read_only = True
                ret[accessor_name] = field

        return ret

    def get_pk_field(self, model_field):
        """
        Returns a default instance of the pk field.
        """
        return self.get_field(model_field)

    def get_nested_field(self, model_field, related_model, to_many):
        """
        Creates a default instance of a nested relational field.

        The result is a ModelSerializer for `related_model` with one less
        level of `depth`, and `to_many` is passed as `many`.
        Note that model_field will be `None` for reverse relationships.
        """
        class NestedModelSerializer(ModelSerializer):
            class Meta:
                model = related_model
                depth = self.opts.depth - 1

        return NestedModelSerializer(many=to_many)

    def get_related_field(self, model_field, related_model, to_many):
        """
        Creates a default instance of a flat relational field.

        Currently this returns a plain IntegerField, and `related_model` and
        `to_many` are not used yet.
        Note that model_field will be `None` for reverse relationships.
        """
        # TODO: return PrimaryKeyRelatedField(**kwargs) with
        # 'queryset': related_model._default_manager and 'many': to_many,
        # filtering the queryset using:
        # .using(db).complex_filter(self.rel.limit_choices_to)
        kwargs = {}

        if model_field:
            kwargs['required'] = not (model_field.null or model_field.blank)
            # if model_field.help_text is not None:
            #     kwargs['help_text'] = model_field.help_text
            if model_field.verbose_name is not None:
                kwargs['label'] = model_field.verbose_name
            if not model_field.editable:
                kwargs['read_only'] = True

        return IntegerField(**kwargs)

    def get_field(self, model_field):
        """
        Creates a default instance of a basic non-relational field.

        A model field with choices becomes a ChoiceField. Otherwise the
        serializer field class is looked up in `field_mapping` by the model
        field's exact class, with CharField as the fallback.
        """
        kwargs = {}

        if model_field.null or model_field.blank:
            kwargs['required'] = False

        if isinstance(model_field, models.AutoField) or not model_field.editable:
            kwargs['read_only'] = True

        if model_field.has_default():
            kwargs['default'] = model_field.get_default()

        if isinstance(model_field, models.TextField):
            # NOTE(review): `widgets` is not imported explicitly in this
            # module, so it must come from one of the star imports. Confirm
            # this, otherwise TextFields raise NameError here.
            kwargs['widget'] = widgets.Textarea

        if model_field.verbose_name is not None:
            kwargs['label'] = model_field.verbose_name

        # if model_field.help_text is not None:
        #     kwargs['help_text'] = model_field.help_text

        # TODO: TypedChoiceField?
        if model_field.flatchoices:  # This ModelField contains choices
            kwargs['choices'] = model_field.flatchoices
            if model_field.null:
                kwargs['empty'] = None
            return ChoiceField(**kwargs)

        # This comes after the ChoiceField branch because min_value isn't
        # a valid ChoiceField initializer.
        if isinstance(model_field, (models.PositiveIntegerField,
                                    models.PositiveSmallIntegerField)):
            kwargs['min_value'] = 0

        if model_field.null and isinstance(model_field,
                                           (models.CharField, models.TextField)):
            kwargs['allow_none'] = True

        # attribute_dict = {
        #     models.CharField: ['max_length'],
        #     models.CommaSeparatedIntegerField: ['max_length'],
        #     models.DecimalField: ['max_digits', 'decimal_places'],
        #     models.EmailField: ['max_length'],
        #     models.FileField: ['max_length'],
        #     models.ImageField: ['max_length'],
        #     models.SlugField: ['max_length'],
        #     models.URLField: ['max_length'],
        # }
        # if model_field.__class__ in attribute_dict:
        #     attributes = attribute_dict[model_field.__class__]
        #     for attribute in attributes:
        #         kwargs.update({attribute: getattr(model_field, attribute)})

        # Only the mapping lookup is guarded. A KeyError raised by a field
        # constructor must propagate, not be silently turned into a CharField.
        try:
            field_class = self.field_mapping[model_field.__class__]
        except KeyError:
            # TODO: Change this to `return ModelField(model_field=model_field, **kwargs)`
            field_class = CharField
        return field_class(**kwargs)
2013-01-18 23:47:57 +04:00
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
    """
    Options for HyperlinkedModelSerializer.

    Adds the following `Meta` options to the ModelSerializer options:

    - `view_name`, defaulting to None.
    - `lookup_field`, defaulting to None.
    - `url_field_name`, defaulting to `api_settings.URL_FIELD_NAME`.
    """
    def __init__(self, meta):
        super(HyperlinkedModelSerializerOptions, self).__init__(meta)
        for option in ('view_name', 'lookup_field'):
            setattr(self, option, getattr(meta, option, None))
        self.url_field_name = getattr(
            meta, 'url_field_name', api_settings.URL_FIELD_NAME
        )
class HyperlinkedModelSerializer(ModelSerializer):
    """
    A ModelSerializer that exposes a hyperlinked identity field.

    The identity field is placed first, under `Meta.url_field_name`, unless
    a field with that name already exists. The primary key is included only
    when it is listed in `Meta.fields`.
    """
    _options_class = HyperlinkedModelSerializerOptions
    _default_view_name = '%(model_name)s-detail'

    _hyperlink_field_class = HyperlinkedRelatedField
    _hyperlink_identify_field_class = HyperlinkedIdentityField

    def get_default_fields(self):
        base = super(HyperlinkedModelSerializer, self).get_default_fields()

        # Fill in the default view name on the options if Meta omitted it.
        if self.opts.view_name is None:
            self.opts.view_name = self._get_default_view_name(self.opts.model)

        if self.opts.url_field_name in base:
            return base

        identity = self._hyperlink_identify_field_class(
            view_name=self.opts.view_name,
            lookup_field=self.opts.lookup_field
        )
        # Rebuild the mapping so that the URL field comes first.
        reordered = type(base)()
        reordered[self.opts.url_field_name] = identity
        reordered.update(base)
        return reordered

    def get_pk_field(self, model_field):
        # Include the pk only when it is explicitly requested in Meta.fields.
        requested = self.opts.fields and model_field.name in self.opts.fields
        if not requested:
            return None
        return self.get_field(model_field)

    def get_related_field(self, model_field, related_model, to_many):
        """
        Creates a default instance of a flat relational field.

        Currently this returns a plain IntegerField, and `related_model` and
        `to_many` are not used yet. `model_field` is None for reverse
        relationships.
        """
        # TODO: return self._hyperlink_field_class(**kwargs) with
        # 'queryset': related_model._default_manager,
        # 'view_name': self._get_default_view_name(related_model),
        # 'many': to_many, plus 'lookup_field' when self.opts.lookup_field
        # is set. Filter the queryset using:
        # .using(db).complex_filter(self.rel.limit_choices_to)
        options = {}
        if model_field:
            options['required'] = not (model_field.null or model_field.blank)
            if model_field.verbose_name is not None:
                options['label'] = model_field.verbose_name
        return IntegerField(**options)

    def _get_default_view_name(self, model):
        """
        Return the view name to use if 'view_name' is not specified in 'Meta'.

        The name is built from `_default_view_name`, using the model's app
        label and its lowercased object name.
        """
        meta = model._meta
        return self._default_view_name % {
            'app_label': meta.app_label,
            'model_name': meta.object_name.lower(),
        }