Added serializers and fields

This commit is contained in:
Tom Christie 2012-08-28 15:46:38 +01:00
parent 9ea12d1412
commit ecd3733c5e
6 changed files with 809 additions and 12 deletions

View File

@ -0,0 +1,446 @@
import copy
import datetime
import inspect
import warnings
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
from django.db.models.related import RelatedObject
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
def is_simple_callable(obj):
"""
True if the object is a callable that takes no arguments.
"""
return (
(inspect.isfunction(obj) and not inspect.getargspec(obj)[0]) or
(inspect.ismethod(obj) and len(inspect.getargspec(obj)[0]) <= 1)
)
class Field(object):
creation_counter = 0
default_validators = []
default_error_messages = {
'required': _('This field is required.'),
'invalid': _('Invalid value.'),
}
empty = ''
def __init__(self, source=None, readonly=False, required=None,
validators=[], error_messages=None):
self.parent = None
self.creation_counter = Field.creation_counter
Field.creation_counter += 1
self.source = source
self.readonly = readonly
self.required = not(readonly)
messages = {}
for c in reversed(self.__class__.__mro__):
messages.update(getattr(c, 'default_error_messages', {}))
messages.update(error_messages or {})
self.error_messages = messages
self.validators = self.default_validators + validators
def initialize(self, parent, model_field=None):
"""
Called to set up a field prior to field_to_native or field_from_native.
parent - The parent serializer.
model_field - The model field this field corrosponds to, if one exists.
"""
self.parent = parent
self.root = parent.root or parent
self.context = self.root.context
if model_field:
self.model_field = model_field
def validate(self, value):
pass
# if value in validators.EMPTY_VALUES and self.required:
# raise ValidationError(self.error_messages['required'])
def run_validators(self, value):
if value in validators.EMPTY_VALUES:
return
errors = []
for v in self.validators:
try:
v(value)
except ValidationError as e:
if hasattr(e, 'code') and e.code in self.error_messages:
message = self.error_messages[e.code]
if e.params:
message = message % e.params
errors.append(message)
else:
errors.extend(e.messages)
if errors:
raise ValidationError(errors)
def field_from_native(self, data, field_name, into):
"""
Given a dictionary and a field name, updates the dictionary `into`,
with the field and it's deserialized value.
"""
if self.readonly:
return
try:
native = data[field_name]
except KeyError:
return # TODO Consider validation behaviour, 'required' opt etc...
value = self.from_native(native)
if self.source == '*':
if value:
into.update(value)
else:
self.validate(value)
self.run_validators(value)
into[self.source or field_name] = value
def from_native(self, value):
"""
Reverts a simple representation back to the field's value.
"""
if hasattr(self, 'model_field'):
try:
return self.model_field.rel.to._meta.get_field(self.model_field.rel.field_name).to_python(value)
except:
return self.model_field.to_python(value)
return value
def field_to_native(self, obj, field_name):
"""
Given and object and a field name, returns the value that should be
serialized for that field.
"""
if obj is None:
return self.empty
if self.source == '*':
return self.to_native(obj)
self.obj = obj # Need to hang onto this in the case of model fields
if hasattr(self, 'model_field'):
return self.to_native(self.model_field._get_val_from_obj(obj))
return self.to_native(getattr(obj, self.source or field_name))
def to_native(self, value):
"""
Converts the field's value into it's simple representation.
"""
if is_simple_callable(value):
value = value()
if is_protected_type(value):
return value
elif hasattr(self, 'model_field'):
return self.model_field.value_to_string(self.obj)
return smart_unicode(value)
def attributes(self):
"""
Returns a dictionary of attributes to be used when serializing to xml.
"""
try:
return {
"type": self.model_field.get_internal_type()
}
except AttributeError:
return {}
class RelatedField(Field):
"""
A base class for model related fields or related managers.
Subclass this and override `convert` to define custom behaviour when
serializing related objects.
"""
def field_to_native(self, obj, field_name):
obj = getattr(obj, field_name)
if obj.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
return [self.to_native(item) for item in obj.all()]
return self.to_native(obj)
def attributes(self):
try:
return {
"rel": self.model_field.rel.__class__.__name__,
"to": smart_unicode(self.model_field.rel.to._meta)
}
except AttributeError:
return {}
class PrimaryKeyRelatedField(RelatedField):
"""
Serializes a model related field or related manager to a pk value.
"""
# Note the we use ModelRelatedField's implementation, as we want to get the
# raw database value directly, since that won't involve another
# database lookup.
#
# An alternative implementation would simply be this...
#
# class PrimaryKeyRelatedField(RelatedField):
# def to_native(self, obj):
# return obj.pk
def to_native(self, pk):
"""
Simply returns the object's pk. You can subclass this method to
provide different serialization behavior of the pk.
(For example returning a URL based on the model's pk.)
"""
return pk
def field_to_native(self, obj, field_name):
try:
obj = obj.serializable_value(field_name)
except AttributeError:
field = obj._meta.get_field_by_name(field_name)[0]
obj = getattr(obj, field_name)
if obj.__class__.__name__ == 'RelatedManager':
return [self.to_native(item.pk) for item in obj.all()]
elif isinstance(field, RelatedObject):
return self.to_native(obj.pk)
raise
if obj.__class__.__name__ == 'ManyRelatedManager':
return [self.to_native(item.pk) for item in obj.all()]
return self.to_native(obj)
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
if hasattr(value, '__iter__'):
into[field_name] = [self.from_native(item) for item in value]
else:
into[field_name + '_id'] = self.from_native(value)
class NaturalKeyRelatedField(RelatedField):
"""
Serializes a model related field or related manager to a natural key value.
"""
is_natural_key = True # XML renderer handles these differently
def to_native(self, obj):
if hasattr(obj, 'natural_key'):
return obj.natural_key()
return obj
def field_from_native(self, data, field_name, into):
value = data.get(field_name)
into[self.model_field.attname] = self.from_native(value)
def from_native(self, value):
# TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS)
manager = self.model_field.rel.to._default_manager
manager = manager.db_manager(DEFAULT_DB_ALIAS)
return manager.get_by_natural_key(*value).pk
class BooleanField(Field):
default_error_messages = {
'invalid': _(u"'%s' value must be either True or False."),
}
def from_native(self, value):
if value in (True, False):
# if value is 1 or 0 than it's equal to True or False, but we want
# to return a true bool for semantic reasons.
return bool(value)
if value in ('t', 'True', '1'):
return True
if value in ('f', 'False', '0'):
return False
raise ValidationError(self.error_messages['invalid'] % value)
class CharField(Field):
def __init__(self, max_length=None, min_length=None, *args, **kwargs):
self.max_length, self.min_length = max_length, min_length
super(CharField, self).__init__(*args, **kwargs)
if min_length is not None:
self.validators.append(validators.MinLengthValidator(min_length))
if max_length is not None:
self.validators.append(validators.MaxLengthValidator(max_length))
def from_native(self, value):
if isinstance(value, basestring) or value is None:
return value
return smart_unicode(value)
class EmailField(CharField):
default_error_messages = {
'invalid': _('Enter a valid e-mail address.'),
}
default_validators = [validators.validate_email]
def from_native(self, value):
return super(EmailField, self).from_native(value).strip()
def __deepcopy__(self, memo):
result = copy.copy(self)
memo[id(self)] = result
#result.widget = copy.deepcopy(self.widget, memo)
result.validators = self.validators[:]
return result
class DateField(Field):
default_error_messages = {
'invalid': _(u"'%s' value has an invalid date format. It must be "
u"in YYYY-MM-DD format."),
'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) "
u"but it is an invalid date."),
}
empty = None
def from_native(self, value):
if value is None:
return value
if isinstance(value, datetime.datetime):
if settings.USE_TZ and timezone.is_aware(value):
# Convert aware datetimes to the default time zone
# before casting them to dates (#17742).
default_timezone = timezone.get_default_timezone()
value = timezone.make_naive(value, default_timezone)
return value.date()
if isinstance(value, datetime.date):
return value
try:
parsed = parse_date(value)
if parsed is not None:
return parsed
except ValueError:
msg = self.error_messages['invalid_date'] % value
raise ValidationError(msg)
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
class DateTimeField(Field):
default_error_messages = {
'invalid': _(u"'%s' value has an invalid format. It must be in "
u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
'invalid_date': _(u"'%s' value has the correct format "
u"(YYYY-MM-DD) but it is an invalid date."),
'invalid_datetime': _(u"'%s' value has the correct format "
u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
u"but it is an invalid date/time."),
}
empty = None
def from_native(self, value):
if value is None:
return value
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
value = datetime.datetime(value.year, value.month, value.day)
if settings.USE_TZ:
# For backwards compatibility, interpret naive datetimes in
# local time. This won't work during DST change, but we can't
# do much about it, so we let the exceptions percolate up the
# call stack.
warnings.warn(u"DateTimeField received a naive datetime (%s)"
u" while time zone support is active." % value,
RuntimeWarning)
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone)
return value
try:
parsed = parse_datetime(value)
if parsed is not None:
return parsed
except ValueError:
msg = self.error_messages['invalid_datetime'] % value
raise ValidationError(msg)
try:
parsed = parse_date(value)
if parsed is not None:
return datetime.datetime(parsed.year, parsed.month, parsed.day)
except ValueError:
msg = self.error_messages['invalid_date'] % value
raise ValidationError(msg)
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
class IntegerField(Field):
default_error_messages = {
'invalid': _('Enter a whole number.'),
'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
}
def __init__(self, max_value=None, min_value=None, *args, **kwargs):
self.max_value, self.min_value = max_value, min_value
super(IntegerField, self).__init__(*args, **kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
try:
value = int(str(value))
except (ValueError, TypeError):
raise ValidationError(self.error_messages['invalid'])
return value
class FloatField(Field):
default_error_messages = {
'invalid': _("'%s' value must be a float."),
}
def from_native(self, value):
if value is None:
return value
try:
return float(value)
except (TypeError, ValueError):
msg = self.error_messages['invalid'] % value
raise ValidationError(msg)
# field_mapping = {
# models.AutoField: IntegerField,
# models.BooleanField: BooleanField,
# models.CharField: CharField,
# models.DateTimeField: DateTimeField,
# models.DateField: DateField,
# models.BigIntegerField: IntegerField,
# models.IntegerField: IntegerField,
# models.PositiveIntegerField: IntegerField,
# models.FloatField: FloatField
# }
# def modelfield_to_serializerfield(field):
# return field_mapping.get(type(field), Field)

View File

@ -57,7 +57,7 @@ class BaseParser(object):
"""
return media_type_matches(self.media_type, content_type)
def parse(self, stream, meta, upload_handlers):
def parse(self, stream, **opts):
"""
Given a *stream* to read from, return the deserialized output.
Should return a 2-tuple of (data, files).
@ -72,7 +72,7 @@ class JSONParser(BaseParser):
media_type = 'application/json'
def parse(self, stream, meta, upload_handlers):
def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
@ -92,7 +92,7 @@ class YAMLParser(BaseParser):
media_type = 'application/yaml'
def parse(self, stream, meta, upload_handlers):
def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
@ -112,7 +112,7 @@ class PlainTextParser(BaseParser):
media_type = 'text/plain'
def parse(self, stream, meta, upload_handlers):
def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
@ -129,7 +129,7 @@ class FormParser(BaseParser):
media_type = 'application/x-www-form-urlencoded'
def parse(self, stream, meta, upload_handlers):
def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
@ -147,13 +147,15 @@ class MultiPartParser(BaseParser):
media_type = 'multipart/form-data'
def parse(self, stream, meta, upload_handlers):
def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.
`data` will be a :class:`QueryDict` containing all the form parameters.
`files` will be a :class:`QueryDict` containing all the form files.
"""
meta = opts['meta']
upload_handlers = opts['upload_handlers']
try:
parser = DjangoMultiPartParser(meta, stream, upload_handlers)
return parser.parse()
@ -168,7 +170,7 @@ class XMLParser(BaseParser):
media_type = 'application/xml'
def parse(self, stream, meta, upload_handlers):
def parse(self, stream, **opts):
"""
Returns a 2-tuple of `(data, files)`.

View File

@ -214,7 +214,8 @@ class Request(object):
for parser in self.get_parsers():
if parser.can_handle_request(self.content_type):
return parser.parse(self.stream, self.META, self.upload_handlers)
return parser.parse(self.stream, meta=self.META,
upload_handlers=self.upload_handlers)
raise UnsupportedMediaType(self._content_type)

View File

@ -0,0 +1,348 @@
from decimal import Decimal
from django.core.serializers.base import DeserializedObject
from django.utils.datastructures import SortedDict
import copy
import datetime
import types
from djangorestframework.fields import *
class DictWithMetadata(dict):
"""
A dict-like object, that can have additional properties attached.
"""
pass
class SortedDictWithMetadata(SortedDict, DictWithMetadata):
"""
A sorted dict-like object, that can have additional properties attached.
"""
pass
class RecursionOccured(BaseException):
pass
def _is_protected_type(obj):
"""
True if the object is a native datatype that does not need to
be serialized further.
"""
return isinstance(obj, (
types.NoneType,
int, long,
datetime.datetime, datetime.date, datetime.time,
float, Decimal,
basestring)
)
def _get_declared_fields(bases, attrs):
"""
Create a list of serializer field instances from the passed in 'attrs',
plus any fields on the base classes (in 'bases').
Note that all fields from the base classes are used.
"""
fields = [(field_name, attrs.pop(field_name))
for field_name, obj in attrs.items()
if isinstance(obj, Field)]
fields.sort(key=lambda x: x[1].creation_counter)
# If this class is subclassing another Serializer, add that Serializer's
# fields. Note that we loop over the bases in *reverse*. This is necessary
# in order to the correct order of fields.
for base in bases[::-1]:
if hasattr(base, 'base_fields'):
fields = base.base_fields.items() + fields
return SortedDict(fields)
class SerializerMetaclass(type):
def __new__(cls, name, bases, attrs):
attrs['base_fields'] = _get_declared_fields(bases, attrs)
return super(SerializerMetaclass, cls).__new__(cls, name, bases, attrs)
class SerializerOptions(object):
"""
Meta class options for ModelSerializer
"""
def __init__(self, meta):
self.nested = getattr(meta, 'nested', False)
self.fields = getattr(meta, 'fields', ())
self.exclude = getattr(meta, 'exclude', ())
class BaseSerializer(Field):
class Meta(object):
pass
_options_class = SerializerOptions
_dict_class = SortedDictWithMetadata # Set to unsorted dict for backwards compatability with unsorted implementations.
def __init__(self, data=None, instance=None, context=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs)
self.fields = copy.deepcopy(self.base_fields)
self.opts = self._options_class(self.Meta)
self.parent = None
self.root = None
self.stack = []
self.context = context or {}
self.init_data = data
self.instance = instance
self._data = None
self._errors = None
#####
# Methods to determine which fields to use when (de)serializing objects.
def default_fields(self, serialize, obj=None, data=None, nested=False):
"""
Return the complete set of default fields for the object, as a dict.
"""
return {}
def get_fields(self, serialize, obj=None, data=None, nested=False):
"""
Returns the complete set of fields for the object as a dict.
This will be the set of any explicitly declared fields,
plus the set of fields returned by get_default_fields().
"""
ret = SortedDict()
# Get the explicitly declared fields
for key, field in self.fields.items():
ret[key] = field
# Determine if the declared field corrosponds to a model field.
try:
if key == 'pk':
model_field = obj._meta.pk
else:
model_field = obj._meta.get_field_by_name(key)[0]
except:
model_field = None
# Set up the field
field.initialize(parent=self, model_field=model_field)
# Add in the default fields
fields = self.default_fields(serialize, obj, data, nested)
for key, val in fields.items():
if key not in ret:
ret[key] = val
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
new = SortedDict()
for key in self.opts.fields:
new[key] = ret[key]
ret = new
# Remove anything in 'exclude'
if self.opts.exclude:
for key in self.opts.exclude:
ret.pop(key, None)
return ret
#####
# Field methods - used when the serializer class is itself used as a field.
def initialize(self, parent, model_field=None):
"""
Same behaviour as usual Field, except that we need to keep track
of state so that we can deal with handling maximum depth and recursion.
"""
super(BaseSerializer, self).initialize(parent, model_field)
self.stack = parent.stack[:]
if parent.opts.nested and not isinstance(parent.opts.nested, bool):
self.opts.nested = parent.opts.nested - 1
else:
self.opts.nested = parent.opts.nested
#####
# Methods to convert or revert from objects <--> primative representations.
def get_field_key(self, field_name):
"""
Return the key that should be used for a given field.
"""
return field_name
def convert_object(self, obj):
"""
Core of serialization.
Convert an object into a dictionary of serialized field values.
"""
if obj in self.stack and not self.source == '*':
raise RecursionOccured()
self.stack.append(obj)
ret = self._dict_class()
ret.fields = {}
fields = self.get_fields(serialize=True, obj=obj, nested=self.opts.nested)
for field_name, field in fields.items():
key = self.get_field_key(field_name)
try:
value = field.field_to_native(obj, field_name)
except RecursionOccured:
field = self.get_fields(serialize=True, obj=obj, nested=False)[field_name]
value = field.field_to_native(obj, field_name)
ret[key] = value
ret.fields[key] = field
return ret
def restore_fields(self, data):
"""
Core of deserialization, together with `restore_object`.
Converts a dictionary of data into a dictionary of deserialized fields.
"""
fields = self.get_fields(serialize=False, data=data, nested=self.opts.nested)
reverted_data = {}
for field_name, field in fields.items():
try:
field.field_from_native(data, field_name, reverted_data)
except ValidationError as err:
self._errors[field_name] = list(err.messages)
return reverted_data
def restore_object(self, attrs, instance=None):
"""
Deserialize a dictionary of attributes into an object instance.
You should override this method to control how deserialized objects
are instantiated.
"""
if instance is not None:
instance.update(attrs)
return instance
return attrs
def to_native(self, obj):
"""
Serialize objects -> primatives.
"""
if isinstance(obj, dict):
return dict([(key, self.to_native(val))
for (key, val) in obj.items()])
elif hasattr(obj, '__iter__'):
return (self.to_native(item) for item in obj)
return self.convert_object(obj)
def from_native(self, data):
"""
Deserialize primatives -> objects.
"""
if hasattr(data, '__iter__') and not isinstance(data, dict):
# TODO: error data when deserializing lists
return (self.from_native(item) for item in data)
self._errors = {}
attrs = self.restore_fields(data)
if not self._errors:
return self.restore_object(attrs, instance=getattr(self, 'instance', None))
@property
def errors(self):
"""
Run deserialization and return error data,
setting self.object if no errors occured.
"""
if self._errors is None:
obj = self.from_native(self.init_data)
if not self._errors:
self.object = obj
return self._errors
def is_valid(self):
return not self.errors
@property
def data(self):
if self._data is None:
self._data = self.to_native(self.instance)
return self._data
class Serializer(BaseSerializer):
__metaclass__ = SerializerMetaclass
class ModelSerializerOptions(SerializerOptions):
"""
Meta class options for ModelSerializer
"""
def __init__(self, meta):
super(ModelSerializerOptions, self).__init__(meta)
self.model = getattr(meta, 'model', None)
class ModelSerializer(RelatedField, Serializer):
"""
A serializer that deals with model instances and querysets.
"""
_options_class = ModelSerializerOptions
def default_fields(self, serialize, obj=None, data=None, nested=False):
"""
Return all the fields that should be serialized for the model.
"""
if serialize:
cls = obj.__class__
else:
cls = self.opts.model
opts = cls._meta.concrete_model._meta
pk_field = opts.pk
while pk_field.rel:
pk_field = pk_field.rel.to._meta.pk
fields = [pk_field]
fields += [field for field in opts.fields if field.serialize]
fields += [field for field in opts.many_to_many if field.serialize]
ret = SortedDict()
for model_field in fields:
if model_field.rel and nested:
field = self.get_nested_field(model_field)
elif model_field.rel:
field = self.get_related_field(model_field)
else:
field = self.get_field(model_field)
field.initialize(parent=self, model_field=model_field)
ret[model_field.name] = field
return ret
def get_nested_field(self, model_field):
"""
Creates a default instance of a nested relational field.
"""
return ModelSerializer()
def get_related_field(self, model_field):
"""
Creates a default instance of a flat relational field.
"""
return PrimaryKeyRelatedField()
def get_field(self, model_field):
"""
Creates a default instance of a basic field.
"""
return Field()
def restore_object(self, attrs, instance=None):
"""
Restore the model instance.
"""
m2m_data = {}
for field in self.opts.model._meta.many_to_many:
if field.name in attrs:
m2m_data[field.name] = attrs.pop(field.name)
return DeserializedObject(self.opts.model(**attrs), m2m_data)

View File

@ -153,7 +153,7 @@ class TestFormParser(TestCase):
parser = FormParser()
stream = StringIO(self.string)
(data, files) = parser.parse(stream, {}, [])
(data, files) = parser.parse(stream)
self.assertEqual(Form(data).is_valid(), True)
@ -203,10 +203,10 @@ class TestXMLParser(TestCase):
def test_parse(self):
parser = XMLParser()
(data, files) = parser.parse(self._input, {}, [])
(data, files) = parser.parse(self._input)
self.assertEqual(data, self._data)
def test_complex_data_parse(self):
parser = XMLParser()
(data, files) = parser.parse(self._complex_data_input, {}, [])
(data, files) = parser.parse(self._complex_data_input)
self.assertEqual(data, self._complex_data)

View File

@ -380,7 +380,7 @@ class XMLRendererTestCase(TestCase):
content = StringIO(renderer.render(self._complex_data, 'application/xml'))
parser = XMLParser()
complex_data_out, dummy = parser.parse(content, {}, [])
complex_data_out, dummy = parser.parse(content)
error_msg = "complex data differs!IN:\n %s \n\n OUT:\n %s" % (repr(self._complex_data), repr(complex_data_out))
self.assertEqual(self._complex_data, complex_data_out, error_msg)