mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-26 03:23:59 +03:00
Added serializers and fields
This commit is contained in:
parent
9ea12d1412
commit
ecd3733c5e
446
djangorestframework/fields.py
Normal file
446
djangorestframework/fields.py
Normal file
|
@ -0,0 +1,446 @@
|
|||
import copy
|
||||
import datetime
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
from django.core import validators
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.conf import settings
|
||||
from django.db import DEFAULT_DB_ALIAS
|
||||
from django.db.models.related import RelatedObject
|
||||
from django.utils import timezone
|
||||
from django.utils.dateparse import parse_date, parse_datetime
|
||||
from django.utils.encoding import is_protected_type, smart_unicode
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
|
||||
def is_simple_callable(obj):
|
||||
"""
|
||||
True if the object is a callable that takes no arguments.
|
||||
"""
|
||||
return (
|
||||
(inspect.isfunction(obj) and not inspect.getargspec(obj)[0]) or
|
||||
(inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1)
|
||||
)
|
||||
|
||||
|
||||
class Field(object):
|
||||
creation_counter = 0
|
||||
default_validators = []
|
||||
default_error_messages = {
|
||||
'required': _('This field is required.'),
|
||||
'invalid': _('Invalid value.'),
|
||||
}
|
||||
empty = ''
|
||||
|
||||
def __init__(self, source=None, readonly=False, required=None,
|
||||
validators=[], error_messages=None):
|
||||
self.parent = None
|
||||
|
||||
self.creation_counter = Field.creation_counter
|
||||
Field.creation_counter += 1
|
||||
|
||||
self.source = source
|
||||
self.readonly = readonly
|
||||
self.required = not(readonly)
|
||||
|
||||
messages = {}
|
||||
for c in reversed(self.__class__.__mro__):
|
||||
messages.update(getattr(c, 'default_error_messages', {}))
|
||||
messages.update(error_messages or {})
|
||||
self.error_messages = messages
|
||||
|
||||
self.validators = self.default_validators + validators
|
||||
|
||||
def initialize(self, parent, model_field=None):
|
||||
"""
|
||||
Called to set up a field prior to field_to_native or field_from_native.
|
||||
|
||||
parent - The parent serializer.
|
||||
model_field - The model field this field corrosponds to, if one exists.
|
||||
"""
|
||||
self.parent = parent
|
||||
self.root = parent.root or parent
|
||||
self.context = self.root.context
|
||||
if model_field:
|
||||
self.model_field = model_field
|
||||
|
||||
def validate(self, value):
|
||||
pass
|
||||
# if value in validators.EMPTY_VALUES and self.required:
|
||||
# raise ValidationError(self.error_messages['required'])
|
||||
|
||||
def run_validators(self, value):
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return
|
||||
errors = []
|
||||
for v in self.validators:
|
||||
try:
|
||||
v(value)
|
||||
except ValidationError as e:
|
||||
if hasattr(e, 'code') and e.code in self.error_messages:
|
||||
message = self.error_messages[e.code]
|
||||
if e.params:
|
||||
message = message % e.params
|
||||
errors.append(message)
|
||||
else:
|
||||
errors.extend(e.messages)
|
||||
if errors:
|
||||
raise ValidationError(errors)
|
||||
|
||||
def field_from_native(self, data, field_name, into):
|
||||
"""
|
||||
Given a dictionary and a field name, updates the dictionary `into`,
|
||||
with the field and it's deserialized value.
|
||||
"""
|
||||
if self.readonly:
|
||||
return
|
||||
|
||||
try:
|
||||
native = data[field_name]
|
||||
except KeyError:
|
||||
return # TODO Consider validation behaviour, 'required' opt etc...
|
||||
|
||||
value = self.from_native(native)
|
||||
if self.source == '*':
|
||||
if value:
|
||||
into.update(value)
|
||||
else:
|
||||
self.validate(value)
|
||||
self.run_validators(value)
|
||||
into[self.source or field_name] = value
|
||||
|
||||
def from_native(self, value):
|
||||
"""
|
||||
Reverts a simple representation back to the field's value.
|
||||
"""
|
||||
if hasattr(self, 'model_field'):
|
||||
try:
|
||||
return self.model_field.rel.to._meta.get_field(self.model_field.rel.field_name).to_python(value)
|
||||
except:
|
||||
return self.model_field.to_python(value)
|
||||
return value
|
||||
|
||||
def field_to_native(self, obj, field_name):
|
||||
"""
|
||||
Given and object and a field name, returns the value that should be
|
||||
serialized for that field.
|
||||
"""
|
||||
if obj is None:
|
||||
return self.empty
|
||||
|
||||
if self.source == '*':
|
||||
return self.to_native(obj)
|
||||
|
||||
self.obj = obj # Need to hang onto this in the case of model fields
|
||||
if hasattr(self, 'model_field'):
|
||||
return self.to_native(self.model_field._get_val_from_obj(obj))
|
||||
|
||||
return self.to_native(getattr(obj, self.source or field_name))
|
||||
|
||||
def to_native(self, value):
|
||||
"""
|
||||
Converts the field's value into it's simple representation.
|
||||
"""
|
||||
if is_simple_callable(value):
|
||||
value = value()
|
||||
|
||||
if is_protected_type(value):
|
||||
return value
|
||||
elif hasattr(self, 'model_field'):
|
||||
return self.model_field.value_to_string(self.obj)
|
||||
return smart_unicode(value)
|
||||
|
||||
def attributes(self):
|
||||
"""
|
||||
Returns a dictionary of attributes to be used when serializing to xml.
|
||||
"""
|
||||
try:
|
||||
return {
|
||||
"type": self.model_field.get_internal_type()
|
||||
}
|
||||
except AttributeError:
|
||||
return {}
|
||||
|
||||
|
||||
class RelatedField(Field):
|
||||
"""
|
||||
A base class for model related fields or related managers.
|
||||
|
||||
Subclass this and override `convert` to define custom behaviour when
|
||||
serializing related objects.
|
||||
"""
|
||||
|
||||
def field_to_native(self, obj, field_name):
|
||||
obj = getattr(obj, field_name)
|
||||
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
|
||||
return [self.to_native(item) for item in obj.all()]
|
||||
return self.to_native(obj)
|
||||
|
||||
def attributes(self):
|
||||
try:
|
||||
return {
|
||||
"rel": self.model_field.rel.__class__.__name__,
|
||||
"to": smart_unicode(self.model_field.rel.to._meta)
|
||||
}
|
||||
except AttributeError:
|
||||
return {}
|
||||
|
||||
|
||||
class PrimaryKeyRelatedField(RelatedField):
|
||||
"""
|
||||
Serializes a model related field or related manager to a pk value.
|
||||
"""
|
||||
|
||||
# Note the we use ModelRelatedField's implementation, as we want to get the
|
||||
# raw database value directly, since that won't involve another
|
||||
# database lookup.
|
||||
#
|
||||
# An alternative implementation would simply be this...
|
||||
#
|
||||
# class PrimaryKeyRelatedField(RelatedField):
|
||||
# def to_native(self, obj):
|
||||
# return obj.pk
|
||||
|
||||
def to_native(self, pk):
|
||||
"""
|
||||
Simply returns the object's pk. You can subclass this method to
|
||||
provide different serialization behavior of the pk.
|
||||
(For example returning a URL based on the model's pk.)
|
||||
"""
|
||||
return pk
|
||||
|
||||
def field_to_native(self, obj, field_name):
|
||||
try:
|
||||
obj = obj.serializable_value(field_name)
|
||||
except AttributeError:
|
||||
field = obj._meta.get_field_by_name(field_name)[0]
|
||||
obj = getattr(obj, field_name)
|
||||
if obj.__class__.__name__ == 'RelatedManager':
|
||||
return [self.to_native(item.pk) for item in obj.all()]
|
||||
elif isinstance(field, RelatedObject):
|
||||
return self.to_native(obj.pk)
|
||||
raise
|
||||
if obj.__class__.__name__ == 'ManyRelatedManager':
|
||||
return [self.to_native(item.pk) for item in obj.all()]
|
||||
return self.to_native(obj)
|
||||
|
||||
def field_from_native(self, data, field_name, into):
|
||||
value = data.get(field_name)
|
||||
if hasattr(value, '__iter__'):
|
||||
into[field_name] = [self.from_native(item) for item in value]
|
||||
else:
|
||||
into[field_name + '_id'] = self.from_native(value)
|
||||
|
||||
|
||||
class NaturalKeyRelatedField(RelatedField):
|
||||
"""
|
||||
Serializes a model related field or related manager to a natural key value.
|
||||
"""
|
||||
is_natural_key = True # XML renderer handles these differently
|
||||
|
||||
def to_native(self, obj):
|
||||
if hasattr(obj, 'natural_key'):
|
||||
return obj.natural_key()
|
||||
return obj
|
||||
|
||||
def field_from_native(self, data, field_name, into):
|
||||
value = data.get(field_name)
|
||||
into[self.model_field.attname] = self.from_native(value)
|
||||
|
||||
def from_native(self, value):
|
||||
# TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS)
|
||||
manager = self.model_field.rel.to._default_manager
|
||||
manager = manager.db_manager(DEFAULT_DB_ALIAS)
|
||||
return manager.get_by_natural_key(*value).pk
|
||||
|
||||
|
||||
class BooleanField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _(u"'%s' value must be either True or False."),
|
||||
}
|
||||
|
||||
def from_native(self, value):
|
||||
if value in (True, False):
|
||||
# if value is 1 or 0 than it's equal to True or False, but we want
|
||||
# to return a true bool for semantic reasons.
|
||||
return bool(value)
|
||||
if value in ('t', 'True', '1'):
|
||||
return True
|
||||
if value in ('f', 'False', '0'):
|
||||
return False
|
||||
raise ValidationError(self.error_messages['invalid'] % value)
|
||||
|
||||
|
||||
class CharField(Field):
|
||||
def __init__(self, max_length=None, min_length=None, *args, **kwargs):
|
||||
self.max_length, self.min_length = max_length, min_length
|
||||
super(CharField, self).__init__(*args, **kwargs)
|
||||
if min_length is not None:
|
||||
self.validators.append(validators.MinLengthValidator(min_length))
|
||||
if max_length is not None:
|
||||
self.validators.append(validators.MaxLengthValidator(max_length))
|
||||
|
||||
def from_native(self, value):
|
||||
if isinstance(value, basestring) or value is None:
|
||||
return value
|
||||
return smart_unicode(value)
|
||||
|
||||
|
||||
class EmailField(CharField):
|
||||
default_error_messages = {
|
||||
'invalid': _('Enter a valid e-mail address.'),
|
||||
}
|
||||
default_validators = [validators.validate_email]
|
||||
|
||||
def from_native(self, value):
|
||||
return super(EmailField, self).from_native(value).strip()
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
result = copy.copy(self)
|
||||
memo[id(self)] = result
|
||||
#result.widget = copy.deepcopy(self.widget, memo)
|
||||
result.validators = self.validators[:]
|
||||
return result
|
||||
|
||||
|
||||
class DateField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _(u"'%s' value has an invalid date format. It must be "
|
||||
u"in YYYY-MM-DD format."),
|
||||
'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) "
|
||||
u"but it is an invalid date."),
|
||||
}
|
||||
empty = None
|
||||
|
||||
def from_native(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, datetime.datetime):
|
||||
if settings.USE_TZ and timezone.is_aware(value):
|
||||
# Convert aware datetimes to the default time zone
|
||||
# before casting them to dates (#17742).
|
||||
default_timezone = timezone.get_default_timezone()
|
||||
value = timezone.make_naive(value, default_timezone)
|
||||
return value.date()
|
||||
if isinstance(value, datetime.date):
|
||||
return value
|
||||
|
||||
try:
|
||||
parsed = parse_date(value)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
except ValueError:
|
||||
msg = self.error_messages['invalid_date'] % value
|
||||
raise ValidationError(msg)
|
||||
|
||||
msg = self.error_messages['invalid'] % value
|
||||
raise ValidationError(msg)
|
||||
|
||||
|
||||
class DateTimeField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _(u"'%s' value has an invalid format. It must be in "
|
||||
u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
|
||||
'invalid_date': _(u"'%s' value has the correct format "
|
||||
u"(YYYY-MM-DD) but it is an invalid date."),
|
||||
'invalid_datetime': _(u"'%s' value has the correct format "
|
||||
u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
|
||||
u"but it is an invalid date/time."),
|
||||
}
|
||||
empty = None
|
||||
|
||||
def from_native(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value
|
||||
if isinstance(value, datetime.date):
|
||||
value = datetime.datetime(value.year, value.month, value.day)
|
||||
if settings.USE_TZ:
|
||||
# For backwards compatibility, interpret naive datetimes in
|
||||
# local time. This won't work during DST change, but we can't
|
||||
# do much about it, so we let the exceptions percolate up the
|
||||
# call stack.
|
||||
warnings.warn(u"DateTimeField received a naive datetime (%s)"
|
||||
u" while time zone support is active." % value,
|
||||
RuntimeWarning)
|
||||
default_timezone = timezone.get_default_timezone()
|
||||
value = timezone.make_aware(value, default_timezone)
|
||||
return value
|
||||
|
||||
try:
|
||||
parsed = parse_datetime(value)
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
except ValueError:
|
||||
msg = self.error_messages['invalid_datetime'] % value
|
||||
raise ValidationError(msg)
|
||||
|
||||
try:
|
||||
parsed = parse_date(value)
|
||||
if parsed is not None:
|
||||
return datetime.datetime(parsed.year, parsed.month, parsed.day)
|
||||
except ValueError:
|
||||
msg = self.error_messages['invalid_date'] % value
|
||||
raise ValidationError(msg)
|
||||
|
||||
msg = self.error_messages['invalid'] % value
|
||||
raise ValidationError(msg)
|
||||
|
||||
|
||||
class IntegerField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _('Enter a whole number.'),
|
||||
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
|
||||
'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
|
||||
}
|
||||
|
||||
def __init__(self, max_value=None, min_value=None, *args, **kwargs):
|
||||
self.max_value, self.min_value = max_value, min_value
|
||||
super(IntegerField, self).__init__(*args, **kwargs)
|
||||
|
||||
if max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(max_value))
|
||||
if min_value is not None:
|
||||
self.validators.append(validators.MinValueValidator(min_value))
|
||||
|
||||
def from_native(self, value):
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return None
|
||||
try:
|
||||
value = int(str(value))
|
||||
except (ValueError, TypeError):
|
||||
raise ValidationError(self.error_messages['invalid'])
|
||||
return value
|
||||
|
||||
|
||||
class FloatField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _("'%s' value must be a float."),
|
||||
}
|
||||
|
||||
def from_native(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
msg = self.error_messages['invalid'] % value
|
||||
raise ValidationError(msg)
|
||||
|
||||
# field_mapping = {
|
||||
# models.AutoField: IntegerField,
|
||||
# models.BooleanField: BooleanField,
|
||||
# models.CharField: CharField,
|
||||
# models.DateTimeField: DateTimeField,
|
||||
# models.DateField: DateField,
|
||||
# models.BigIntegerField: IntegerField,
|
||||
# models.IntegerField: IntegerField,
|
||||
# models.PositiveIntegerField: IntegerField,
|
||||
# models.FloatField: FloatField
|
||||
# }
|
||||
|
||||
|
||||
# def modelfield_to_serializerfield(field):
|
||||
# return field_mapping.get(type(field), Field)
|
|
@ -57,7 +57,7 @@ class BaseParser(object):
|
|||
"""
|
||||
return media_type_matches(self.media_type, content_type)
|
||||
|
||||
def parse(self, stream, meta, upload_handlers):
|
||||
def parse(self, stream, **opts):
|
||||
"""
|
||||
Given a *stream* to read from, return the deserialized output.
|
||||
Should return a 2-tuple of (data, files).
|
||||
|
@ -72,7 +72,7 @@ class JSONParser(BaseParser):
|
|||
|
||||
media_type = 'application/json'
|
||||
|
||||
def parse(self, stream, meta, upload_handlers):
|
||||
def parse(self, stream, **opts):
|
||||
"""
|
||||
Returns a 2-tuple of `(data, files)`.
|
||||
|
||||
|
@ -92,7 +92,7 @@ class YAMLParser(BaseParser):
|
|||
|
||||
media_type = 'application/yaml'
|
||||
|
||||
def parse(self, stream, meta, upload_handlers):
|
||||
def parse(self, stream, **opts):
|
||||
"""
|
||||
Returns a 2-tuple of `(data, files)`.
|
||||
|
||||
|
@ -112,7 +112,7 @@ class PlainTextParser(BaseParser):
|
|||
|
||||
media_type = 'text/plain'
|
||||
|
||||
def parse(self, stream, meta, upload_handlers):
|
||||
def parse(self, stream, **opts):
|
||||
"""
|
||||
Returns a 2-tuple of `(data, files)`.
|
||||
|
||||
|
@ -129,7 +129,7 @@ class FormParser(BaseParser):
|
|||
|
||||
media_type = 'application/x-www-form-urlencoded'
|
||||
|
||||
def parse(self, stream, meta, upload_handlers):
|
||||
def parse(self, stream, **opts):
|
||||
"""
|
||||
Returns a 2-tuple of `(data, files)`.
|
||||
|
||||
|
@ -147,13 +147,15 @@ class MultiPartParser(BaseParser):
|
|||
|
||||
media_type = 'multipart/form-data'
|
||||
|
||||
def parse(self, stream, meta, upload_handlers):
|
||||
def parse(self, stream, **opts):
|
||||
"""
|
||||
Returns a 2-tuple of `(data, files)`.
|
||||
|
||||
`data` will be a :class:`QueryDict` containing all the form parameters.
|
||||
`files` will be a :class:`QueryDict` containing all the form files.
|
||||
"""
|
||||
meta = opts['meta']
|
||||
upload_handlers = opts['upload_handlers']
|
||||
try:
|
||||
parser = DjangoMultiPartParser(meta, stream, upload_handlers)
|
||||
return parser.parse()
|
||||
|
@ -168,7 +170,7 @@ class XMLParser(BaseParser):
|
|||
|
||||
media_type = 'application/xml'
|
||||
|
||||
def parse(self, stream, meta, upload_handlers):
|
||||
def parse(self, stream, **opts):
|
||||
"""
|
||||
Returns a 2-tuple of `(data, files)`.
|
||||
|
||||
|
|
|
@ -214,7 +214,8 @@ class Request(object):
|
|||
|
||||
for parser in self.get_parsers():
|
||||
if parser.can_handle_request(self.content_type):
|
||||
return parser.parse(self.stream, self.META, self.upload_handlers)
|
||||
return parser.parse(self.stream, meta=self.META,
|
||||
upload_handlers=self.upload_handlers)
|
||||
|
||||
raise UnsupportedMediaType(self._content_type)
|
||||
|
||||
|
|
348
djangorestframework/serializers.py
Normal file
348
djangorestframework/serializers.py
Normal file
|
@ -0,0 +1,348 @@
|
|||
from decimal import Decimal
|
||||
from django.core.serializers.base import DeserializedObject
|
||||
from django.utils.datastructures import SortedDict
|
||||
import copy
|
||||
import datetime
|
||||
import types
|
||||
from djangorestframework.fields import *
|
||||
|
||||
|
||||
class DictWithMetadata(dict):
|
||||
"""
|
||||
A dict-like object, that can have additional properties attached.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class SortedDictWithMetadata(SortedDict, DictWithMetadata):
|
||||
"""
|
||||
A sorted dict-like object, that can have additional properties attached.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RecursionOccured(BaseException):
|
||||
pass
|
||||
|
||||
|
||||
def _is_protected_type(obj):
|
||||
"""
|
||||
True if the object is a native datatype that does not need to
|
||||
be serialized further.
|
||||
"""
|
||||
return isinstance(obj, (
|
||||
types.NoneType,
|
||||
int, long,
|
||||
datetime.datetime, datetime.date, datetime.time,
|
||||
float, Decimal,
|
||||
basestring)
|
||||
)
|
||||
|
||||
|
||||
def _get_declared_fields(bases, attrs):
|
||||
"""
|
||||
Create a list of serializer field instances from the passed in 'attrs',
|
||||
plus any fields on the base classes (in 'bases').
|
||||
|
||||
Note that all fields from the base classes are used.
|
||||
"""
|
||||
fields = [(field_name, attrs.pop(field_name))
|
||||
for field_name, obj in attrs.items()
|
||||
if isinstance(obj, Field)]
|
||||
fields.sort(key=lambda x: x[1].creation_counter)
|
||||
|
||||
# If this class is subclassing another Serializer, add that Serializer's
|
||||
# fields. Note that we loop over the bases in *reverse*. This is necessary
|
||||
# in order to the correct order of fields.
|
||||
for base in bases[::-1]:
|
||||
if hasattr(base, 'base_fields'):
|
||||
fields = base.base_fields.items() + fields
|
||||
|
||||
return SortedDict(fields)
|
||||
|
||||
|
||||
class SerializerMetaclass(type):
|
||||
def __new__(cls, name, bases, attrs):
|
||||
attrs['base_fields'] = _get_declared_fields(bases, attrs)
|
||||
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
|
||||
|
||||
|
||||
class SerializerOptions(object):
|
||||
"""
|
||||
Meta class options for ModelSerializer
|
||||
"""
|
||||
def __init__(self, meta):
|
||||
self.nested = getattr(meta, 'nested', False)
|
||||
self.fields = getattr(meta, 'fields', ())
|
||||
self.exclude = getattr(meta, 'exclude', ())
|
||||
|
||||
|
||||
class BaseSerializer(Field):
|
||||
class Meta(object):
|
||||
pass
|
||||
|
||||
_options_class = SerializerOptions
|
||||
_dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
|
||||
|
||||
def __init__(self, data=None, instance=None, context=None, **kwargs):
|
||||
super(BaseSerializer, self).__init__(**kwargs)
|
||||
self.fields = copy.deepcopy(self.base_fields)
|
||||
self.opts = self._options_class(self.Meta)
|
||||
self.parent = None
|
||||
self.root = None
|
||||
|
||||
self.stack = []
|
||||
self.context = context or {}
|
||||
|
||||
self.init_data = data
|
||||
self.instance = instance
|
||||
|
||||
self._data = None
|
||||
self._errors = None
|
||||
|
||||
#####
|
||||
# Methods to determine which fields to use when (de)serializing objects.
|
||||
|
||||
def default_fields(self, serialize, obj=None, data=None, nested=False):
|
||||
"""
|
||||
Return the complete set of default fields for the object, as a dict.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_fields(self, serialize, obj=None, data=None, nested=False):
|
||||
"""
|
||||
Returns the complete set of fields for the object as a dict.
|
||||
|
||||
This will be the set of any explicitly declared fields,
|
||||
plus the set of fields returned by get_default_fields().
|
||||
"""
|
||||
ret = SortedDict()
|
||||
|
||||
# Get the explicitly declared fields
|
||||
for key, field in self.fields.items():
|
||||
ret[key] = field
|
||||
# Determine if the declared field corrosponds to a model field.
|
||||
try:
|
||||
if key == 'pk':
|
||||
model_field = obj._meta.pk
|
||||
else:
|
||||
model_field = obj._meta.get_field_by_name(key)[0]
|
||||
except:
|
||||
model_field = None
|
||||
# Set up the field
|
||||
field.initialize(parent=self, model_field=model_field)
|
||||
|
||||
# Add in the default fields
|
||||
fields = self.default_fields(serialize, obj, data, nested)
|
||||
for key, val in fields.items():
|
||||
if key not in ret:
|
||||
ret[key] = val
|
||||
|
||||
# If 'fields' is specified, use those fields, in that order.
|
||||
if self.opts.fields:
|
||||
new = SortedDict()
|
||||
for key in self.opts.fields:
|
||||
new[key] = ret[key]
|
||||
ret = new
|
||||
|
||||
# Remove anything in 'exclude'
|
||||
if self.opts.exclude:
|
||||
for key in self.opts.exclude:
|
||||
ret.pop(key, None)
|
||||
|
||||
return ret
|
||||
|
||||
#####
|
||||
# Field methods - used when the serializer class is itself used as a field.
|
||||
|
||||
def initialize(self, parent, model_field=None):
|
||||
"""
|
||||
Same behaviour as usual Field, except that we need to keep track
|
||||
of state so that we can deal with handling maximum depth and recursion.
|
||||
"""
|
||||
super(BaseSerializer, self).initialize(parent, model_field)
|
||||
self.stack = parent.stack[:]
|
||||
if parent.opts.nested and not isinstance(parent.opts.nested, bool):
|
||||
self.opts.nested = parent.opts.nested - 1
|
||||
else:
|
||||
self.opts.nested = parent.opts.nested
|
||||
|
||||
#####
|
||||
# Methods to convert or revert from objects <--> primative representations.
|
||||
|
||||
def get_field_key(self, field_name):
|
||||
"""
|
||||
Return the key that should be used for a given field.
|
||||
"""
|
||||
return field_name
|
||||
|
||||
def convert_object(self, obj):
|
||||
"""
|
||||
Core of serialization.
|
||||
Convert an object into a dictionary of serialized field values.
|
||||
"""
|
||||
if obj in self.stack and not self.source == '*':
|
||||
raise RecursionOccured()
|
||||
self.stack.append(obj)
|
||||
|
||||
ret = self._dict_class()
|
||||
ret.fields = {}
|
||||
|
||||
fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested)
|
||||
for field_name, field in fields.items():
|
||||
key = self.get_field_key(field_name)
|
||||
try:
|
||||
value = field.field_to_native(obj, field_name)
|
||||
except RecursionOccured:
|
||||
field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name]
|
||||
value = field.field_to_native(obj, field_name)
|
||||
ret[key] = value
|
||||
ret.fields[key] = field
|
||||
return ret
|
||||
|
||||
def restore_fields(self, data):
|
||||
"""
|
||||
Core of deserialization, together with `restore_object`.
|
||||
Converts a dictionary of data into a dictionary of deserialized fields.
|
||||
"""
|
||||
fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested)
|
||||
reverted_data = {}
|
||||
for field_name, field in fields.items():
|
||||
try:
|
||||
field.field_from_native(data, field_name, reverted_data)
|
||||
except ValidationError as err:
|
||||
self._errors[field_name] = list(err.messages)
|
||||
|
||||
return reverted_data
|
||||
|
||||
def restore_object(self, attrs, instance=None):
|
||||
"""
|
||||
Deserialize a dictionary of attributes into an object instance.
|
||||
You should override this method to control how deserialized objects
|
||||
are instantiated.
|
||||
"""
|
||||
if instance is not None:
|
||||
instance.update(attrs)
|
||||
return instance
|
||||
return attrs
|
||||
|
||||
def to_native(self, obj):
|
||||
"""
|
||||
Serialize objects -> primatives.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return dict([(key, self.to_native(val))
|
||||
for (key, val) in obj.items()])
|
||||
elif hasattr(obj, '__iter__'):
|
||||
return (self.to_native(item) for item in obj)
|
||||
return self.convert_object(obj)
|
||||
|
||||
def from_native(self, data):
|
||||
"""
|
||||
Deserialize primatives -> objects.
|
||||
"""
|
||||
if hasattr(data, '__iter__') and not isinstance(data, dict):
|
||||
# TODO: error data when deserializing lists
|
||||
return (self.from_native(item) for item in data)
|
||||
self._errors = {}
|
||||
attrs = self.restore_fields(data)
|
||||
if not self._errors:
|
||||
return self.restore_object(attrs, instance=getattr(self, 'instance', None))
|
||||
|
||||
@property
|
||||
def errors(self):
|
||||
"""
|
||||
Run deserialization and return error data,
|
||||
setting self.object if no errors occured.
|
||||
"""
|
||||
if self._errors is None:
|
||||
obj = self.from_native(self.init_data)
|
||||
if not self._errors:
|
||||
self.object = obj
|
||||
return self._errors
|
||||
|
||||
def is_valid(self):
|
||||
return not self.errors
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
if self._data is None:
|
||||
self._data = self.to_native(self.instance)
|
||||
return self._data
|
||||
|
||||
|
||||
class Serializer(BaseSerializer):
|
||||
__metaclass__ = SerializerMetaclass
|
||||
|
||||
|
||||
class ModelSerializerOptions(SerializerOptions):
|
||||
"""
|
||||
Meta class options for ModelSerializer
|
||||
"""
|
||||
def __init__(self, meta):
|
||||
super(ModelSerializerOptions, self).__init__(meta)
|
||||
self.model = getattr(meta, 'model', None)
|
||||
|
||||
|
||||
class ModelSerializer(RelatedField, Serializer):
|
||||
"""
|
||||
A serializer that deals with model instances and querysets.
|
||||
"""
|
||||
_options_class = ModelSerializerOptions
|
||||
|
||||
def default_fields(self, serialize, obj=None, data=None, nested=False):
|
||||
"""
|
||||
Return all the fields that should be serialized for the model.
|
||||
"""
|
||||
if serialize:
|
||||
cls = obj.__class__
|
||||
else:
|
||||
cls = self.opts.model
|
||||
|
||||
opts = cls._meta.concrete_model._meta
|
||||
pk_field = opts.pk
|
||||
while pk_field.rel:
|
||||
pk_field = pk_field.rel.to._meta.pk
|
||||
fields = [pk_field]
|
||||
fields += [field for field in opts.fields if field.serialize]
|
||||
fields += [field for field in opts.many_to_many if field.serialize]
|
||||
|
||||
ret = SortedDict()
|
||||
for model_field in fields:
|
||||
if model_field.rel and nested:
|
||||
field = self.get_nested_field(model_field)
|
||||
elif model_field.rel:
|
||||
field = self.get_related_field(model_field)
|
||||
else:
|
||||
field = self.get_field(model_field)
|
||||
field.initialize(parent=self, model_field=model_field)
|
||||
ret[model_field.name] = field
|
||||
return ret
|
||||
|
||||
def get_nested_field(self, model_field):
|
||||
"""
|
||||
Creates a default instance of a nested relational field.
|
||||
"""
|
||||
return ModelSerializer()
|
||||
|
||||
def get_related_field(self, model_field):
|
||||
"""
|
||||
Creates a default instance of a flat relational field.
|
||||
"""
|
||||
return PrimaryKeyRelatedField()
|
||||
|
||||
def get_field(self, model_field):
|
||||
"""
|
||||
Creates a default instance of a basic field.
|
||||
"""
|
||||
return Field()
|
||||
|
||||
def restore_object(self, attrs, instance=None):
|
||||
"""
|
||||
Restore the model instance.
|
||||
"""
|
||||
m2m_data = {}
|
||||
for field in self.opts.model._meta.many_to_many:
|
||||
if field.name in attrs:
|
||||
m2m_data[field.name] = attrs.pop(field.name)
|
||||
return DeserializedObject(self.opts.model(**attrs), m2m_data)
|
|
@ -153,7 +153,7 @@ class TestFormParser(TestCase):
|
|||
parser = FormParser()
|
||||
|
||||
stream = StringIO(self.string)
|
||||
(data, files) = parser.parse(stream, {}, [])
|
||||
(data, files) = parser.parse(stream)
|
||||
|
||||
self.assertEqual(Form(data).is_valid(), True)
|
||||
|
||||
|
@ -203,10 +203,10 @@ class TestXMLParser(TestCase):
|
|||
|
||||
def test_parse(self):
|
||||
parser = XMLParser()
|
||||
(data, files) = parser.parse(self._input, {}, [])
|
||||
(data, files) = parser.parse(self._input)
|
||||
self.assertEqual(data, self._data)
|
||||
|
||||
def test_complex_data_parse(self):
|
||||
parser = XMLParser()
|
||||
(data, files) = parser.parse(self._complex_data_input, {}, [])
|
||||
(data, files) = parser.parse(self._complex_data_input)
|
||||
self.assertEqual(data, self._complex_data)
|
||||
|
|
|
@ -380,7 +380,7 @@ class XMLRendererTestCase(TestCase):
|
|||
content = StringIO(renderer.render(self._complex_data, 'application/xml'))
|
||||
|
||||
parser = XMLParser()
|
||||
complex_data_out, dummy = parser.parse(content, {}, [])
|
||||
complex_data_out, dummy = parser.parse(content)
|
||||
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
|
||||
self.assertEqual(self._complex_data, complex_data_out, error_msg)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user