django-rest-framework/rest_framework/serializers.py
2014-09-09 17:46:28 +01:00

664 lines
23 KiB
Python

"""
Serializers and ModelSerializers are similar to Forms and ModelForms.
Unlike forms, they are not constrained to dealing with HTML output, and
form encoded input.
Serialization in REST framework is a two-phase process:
1. Serializers marshal between complex types like model instances, and
python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
from django.core import validators
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from collections import namedtuple, OrderedDict
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html, modelinfo, representation
import copy
# Note: We do the following so that users of the framework can use this style:
#
# example_field = serializers.CharField(...)
#
# This helps keep the separation between model fields, form fields, and
# serializer fields more explicit.
from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA
FieldResult = namedtuple('FieldResult', ['field', 'value', 'error'])
class BaseSerializer(Field):
"""
The BaseSerializer class provides a minimal class which may be used
for writing custom serializer implementations.
"""
def __init__(self, instance=None, data=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.instance = instance
self._initial_data = data
def to_native(self, data):
raise NotImplementedError('`to_native()` must be implemented.')
def to_primative(self, instance):
raise NotImplementedError('`to_primative()` must be implemented.')
def update(self, instance, attrs):
raise NotImplementedError('`update()` must be implemented.')
def create(self, attrs):
raise NotImplementedError('`create()` must be implemented.')
def save(self, extras=None):
attrs = self.validated_data
if extras is not None:
attrs = dict(list(attrs.items()) + list(extras.items()))
if self.instance is not None:
self.update(self.instance, attrs)
else:
self.instance = self.create(attrs)
return self.instance
def is_valid(self, raise_exception=False):
if not hasattr(self, '_validated_data'):
try:
self._validated_data = self.to_native(self._initial_data)
except ValidationError as exc:
self._validated_data = {}
self._errors = exc.message_dict
else:
self._errors = {}
if self._errors and raise_exception:
raise ValidationError(self._errors)
return not bool(self._errors)
@property
def data(self):
if not hasattr(self, '_data'):
if self.instance is not None:
self._data = self.to_primative(self.instance)
elif self._initial_data is not None:
self._data = {
field_name: field.get_value(self._initial_data)
for field_name, field in self.fields.items()
}
else:
self._data = self.get_initial()
return self._data
@property
def errors(self):
if not hasattr(self, '_errors'):
msg = 'You must call `.is_valid()` before accessing `.errors`.'
raise AssertionError(msg)
return self._errors
@property
def validated_data(self):
if not hasattr(self, '_validated_data'):
msg = 'You must call `.is_valid()` before accessing `.validated_data`.'
raise AssertionError(msg)
return self._validated_data
class SerializerMetaclass(type):
"""
This metaclass sets a dictionary named `base_fields` on the class.
Any fields included as attributes on either the class or it's superclasses
will be include in the `base_fields` dictionary.
"""
@classmethod
def _get_fields(cls, bases, attrs):
fields = [(field_name, attrs.pop(field_name))
for field_name, obj in list(attrs.items())
if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1]._creation_counter)
# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to maintain the correct order of fields.
for base in bases[::-1]:
if hasattr(base, 'base_fields'):
fields = list(base.base_fields.items()) + fields
return OrderedDict(fields)
def __new__(cls, name, bases, attrs):
attrs['base_fields'] = cls._get_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
@six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer):
def __new__(cls, *args, **kwargs):
if kwargs.pop('many', False):
kwargs['child'] = cls()
return ListSerializer(*args, **kwargs)
return super(Serializer, cls).__new__(cls, *args, **kwargs)
def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
kwargs.pop('many', False)
super(Serializer, self).__init__(*args, **kwargs)
# Every new serializer is created with a clone of the field instances.
# This allows users to dynamically modify the fields on a serializer
# instance without affecting every other serializer class.
self.fields = self.get_fields()
# Setup all the child fields, to provide them with the current context.
for field_name, field in self.fields.items():
field.bind(field_name, self, self)
def get_fields(self):
return copy.deepcopy(self.base_fields)
def bind(self, field_name, parent, root):
# If the serializer is used as a field then when it becomes bound
# it also needs to bind all its child fields.
super(Serializer, self).bind(field_name, parent, root)
for field_name, field in self.fields.items():
field.bind(field_name, self, root)
def get_initial(self):
return {
field.field_name: field.get_initial()
for field in self.fields.values()
}
def get_value(self, dictionary):
# We override the default field access in order to support
# nested HTML forms.
if html.is_html_input(dictionary):
return html.parse_html_dict(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty)
def to_native(self, data):
"""
Dict of native values <- Dict of primitive datatypes.
"""
if not isinstance(data, dict):
raise ValidationError({'non_field_errors': ['Invalid data']})
ret = {}
errors = {}
fields = [field for field in self.fields.values() if not field.read_only]
for field in fields:
validate_method = getattr(self, 'validate_' + field.field_name, None)
primitive_value = field.get_value(data)
try:
validated_value = field.validate(primitive_value)
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc:
errors[field.field_name] = exc.messages
except SkipField:
pass
else:
set_value(ret, field.source_attrs, validated_value)
if errors:
raise ValidationError(errors)
try:
return self.validate(ret)
except ValidationError, exc:
raise ValidationError({'non_field_errors': exc.messages})
def to_primative(self, instance):
"""
Object instance -> Dict of primitive datatypes.
"""
ret = OrderedDict()
fields = [field for field in self.fields.values() if not field.write_only]
for field in fields:
native_value = field.get_attribute(instance)
ret[field.field_name] = field.to_primative(native_value)
return ret
def validate(self, attrs):
return attrs
def __iter__(self):
errors = self.errors if hasattr(self, '_errors') else {}
for field in self.fields.values():
value = self.data.get(field.field_name) if self.data else None
error = errors.get(field.field_name)
yield FieldResult(field, value, error)
def __repr__(self):
return representation.serializer_repr(self, indent=1)
class ListSerializer(BaseSerializer):
child = None
initial = []
def __init__(self, *args, **kwargs):
self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.'
self.context = kwargs.pop('context', {})
kwargs.pop('partial', None)
super(ListSerializer, self).__init__(*args, **kwargs)
self.child.bind('', self, self)
def bind(self, field_name, parent, root):
# If the list is used as a field then it needs to provide
# the current context to the child serializer.
super(ListSerializer, self).bind(field_name, parent, root)
self.child.bind(field_name, self, root)
def get_value(self, dictionary):
# We override the default field access in order to support
# lists in HTML forms.
if is_html_input(dictionary):
return html.parse_html_list(dictionary, prefix=self.field_name)
return dictionary.get(self.field_name, empty)
def to_native(self, data):
"""
List of dicts of native values <- List of dicts of primitive datatypes.
"""
if html.is_html_input(data):
data = html.parse_html_list(data)
return [self.child.validate(item) for item in data]
def to_primative(self, data):
"""
List of object instances -> List of dicts of primitive datatypes.
"""
return [self.child.to_primative(item) for item in data]
def create(self, attrs_list):
return [self.child.create(attrs) for attrs in attrs_list]
def save(self):
if self.instance is not None:
self.update(self.instance, self.validated_data)
self.instance = self.create(self.validated_data)
return self.instance
def __repr__(self):
return representation.list_repr(self, indent=1)
class ModelSerializerOptions(object):
"""
Meta class options for ModelSerializer
"""
def __init__(self, meta):
self.model = getattr(meta, 'model')
self.fields = getattr(meta, 'fields', ())
self.depth = getattr(meta, 'depth', 0)
class ModelSerializer(Serializer):
field_mapping = {
models.AutoField: IntegerField,
models.BigIntegerField: IntegerField,
models.BooleanField: BooleanField,
models.CharField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.DateField: DateField,
models.DateTimeField: DateTimeField,
models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.FileField: FileField,
models.FloatField: FloatField,
models.IntegerField: IntegerField,
models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
models.SlugField: SlugField,
models.SmallIntegerField: IntegerField,
models.TextField: CharField,
models.TimeField: TimeField,
models.URLField: URLField,
# models.ImageField: ImageField,
}
_options_class = ModelSerializerOptions
def __init__(self, *args, **kwargs):
self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs)
def create(self, attrs):
ModelClass = self.opts.model
return ModelClass.objects.create(**attrs)
def update(self, obj, attrs):
for attr, value in attrs.items():
setattr(obj, attr, value)
obj.save()
def get_fields(self):
# Get the explicitly declared fields.
fields = copy.deepcopy(self.base_fields)
# Add in the default fields.
for key, val in self.get_default_fields().items():
if key not in fields:
fields[key] = val
# If `fields` is set on the `Meta` class,
# then use only those fields, and in that order.
if self.opts.fields:
fields = OrderedDict([
(key, fields[key]) for key in self.opts.fields
])
return fields
def get_default_fields(self):
"""
Return all the fields that should be serialized for the model.
"""
info = modelinfo.get_field_info(self.opts.model)
ret = OrderedDict()
serializer_pk_field = self.get_pk_field(info.pk)
if serializer_pk_field:
ret[info.pk.name] = serializer_pk_field
# Regular fields
for field_name, field in info.fields.items():
ret[field_name] = self.get_field(field)
# Forward relations
for field_name, relation_info in info.forward_relations.items():
if self.opts.depth:
ret[field_name] = self.get_nested_field(*relation_info)
else:
ret[field_name] = self.get_related_field(*relation_info)
# Reverse relations
for accessor_name, relation_info in info.reverse_relations.items():
if accessor_name in self.opts.fields:
if self.opts.depth:
ret[field_name] = self.get_nested_field(*relation_info)
else:
ret[field_name] = self.get_related_field(*relation_info)
return ret
def get_pk_field(self, model_field):
"""
Returns a default instance of the pk field.
"""
return self.get_field(model_field)
def get_nested_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a nested relational field.
Note that model_field will be `None` for reverse relationships.
"""
class NestedModelSerializer(ModelSerializer):
class Meta:
model = related_model
depth = self.opts.depth - 1
kwargs = {'read_only': True}
if to_many:
kwargs['many'] = True
return NestedModelSerializer(**kwargs)
def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
Note that model_field will be `None` for reverse relationships.
"""
kwargs = {
'queryset': related_model._default_manager,
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
# Read only implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
# Having a default implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None:
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = getattr(model_field, 'min_length', None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None:
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None:
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if validator_kwarg:
kwargs['validators'] = validator_kwarg
# if issubclass(model_field.__class__, models.TextField):
# kwargs['widget'] = widgets.Textarea
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
# TODO: TypedChoiceField?
if model_field.flatchoices: # This ModelField contains choices
kwargs['choices'] = model_field.flatchoices
if model_field.null:
kwargs['empty'] = None
return ChoiceField(**kwargs)
if model_field.null and \
issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
return ModelField(model_field=model_field, **kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
"""
Options for HyperlinkedModelSerializer
"""
def __init__(self, meta):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
self.lookup_field = getattr(meta, 'lookup_field', None)
class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
if self.opts.view_name is None:
self.opts.view_name = self.get_default_view_name(self.opts.model)
url_field_name = api_settings.URL_FIELD_NAME
if url_field_name not in fields:
ret = fields.__class__()
ret[url_field_name] = self.get_url_field()
ret.update(fields)
fields = ret
return fields
def get_pk_field(self, model_field):
if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field)
def get_url_field(self):
kwargs = {
'view_name': self.get_default_view_name(self.opts.model)
}
if self.opts.lookup_field:
kwargs['lookup_field'] = self.opts.lookup_field
return HyperlinkedIdentityField(**kwargs)
def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
"""
kwargs = {
'queryset': related_model._default_manager,
'view_name': self.get_default_view_name(related_model),
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
if model_field.null or model_field.blank:
kwargs['required'] = False
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
return HyperlinkedRelatedField(**kwargs)
def get_default_view_name(self, model):
"""
Return the view name to use for related models.
"""
return '%(model_name)s-detail' % {
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
}