django-rest-framework/rest_framework/fields.py
2012-09-25 13:20:12 +01:00

450 lines
15 KiB
Python

import copy
import datetime
import inspect
import warnings
from django.core import validators
from django.core.exceptions import ValidationError
from django.conf import settings
from django.db import DEFAULT_DB_ALIAS
from django.db.models.related import RelatedObject
from django.utils.encoding import is_protected_type, smart_unicode
from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import parse_date, parse_datetime
from rest_framework.compat import timezone
def is_simple_callable(obj):
    """
    True if the object is a callable that can be called with no arguments.

    Only plain functions and bound methods are considered. A function
    qualifies when every positional argument has a default. A bound method
    qualifies on the same basis once the implicitly supplied first argument
    (``self``/``cls``) is discounted. Keyword-only arguments (Python 3)
    without a default disqualify the callable. Anything else returns False:
    builtins, classes, callable instances and non-callables.
    """
    is_function = inspect.isfunction(obj)
    is_method = inspect.ismethod(obj)
    if not (is_function or is_method):
        return False

    # ``getargspec`` is deprecated on Python 3 (removed in 3.11) and rejects
    # annotated / keyword-only signatures. Prefer ``getfullargspec`` where it
    # exists. Both results expose ``args`` and ``defaults``.
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    spec = getspec(obj)

    num_args = len(spec.args)
    if is_method:
        # The bound first argument is supplied automatically.
        num_args -= 1
    num_defaults = len(spec.defaults or ())

    # Required keyword-only arguments also prevent a no-argument call.
    kwonly = getattr(spec, 'kwonlyargs', None) or []
    kwonly_defaults = getattr(spec, 'kwonlydefaults', None) or {}
    if any(name not in kwonly_defaults for name in kwonly):
        return False

    return num_args <= num_defaults
class Field(object):
    """
    Base class for serializer fields.

    ``field_to_native``/``to_native`` turn a value read from an object into a
    simple, serializable representation. ``field_from_native``/``from_native``
    turn incoming primitive data back into a value and validate it.
    """
    # Incremented on every instantiation and copied onto the instance, so
    # fields can be ordered by creation.
    creation_counter = 0
    default_validators = []
    default_error_messages = {
        'required': _('This field is required.'),
        'invalid': _('Invalid value.'),
    }
    # Returned by field_to_native when the object being serialized is None.
    empty = ''

    def __init__(self, source=None, readonly=False, required=None,
                 validators=None, error_messages=None):
        """
        source - attribute to read from instead of the field name, or '*' to
                 pass the whole object.
        readonly - when True, field_from_native does nothing.
        required - defaults to ``not readonly``. It may not be True when
                   ``readonly`` is True.
        validators - extra validators, appended after ``default_validators``.
        error_messages - overrides merged over the ``default_error_messages``
                         of every class in the MRO.
        """
        self.parent = None

        self.creation_counter = Field.creation_counter
        Field.creation_counter += 1

        self.source = source
        self.readonly = readonly
        if required is None:
            self.required = not readonly
        else:
            # Only the contradictory combination is rejected. An explicit
            # required=False together with readonly=True is allowed.
            assert not (readonly and required), \
                "Cannot set required=True and readonly=True"
            self.required = required

        # Walk the MRO from the most generic class to the most specific, so
        # subclasses (and finally the caller) override individual keys.
        messages = {}
        for klass in reversed(self.__class__.__mro__):
            messages.update(getattr(klass, 'default_error_messages', {}))
        messages.update(error_messages or {})
        self.error_messages = messages

        # Always build a fresh list: subclasses append to self.validators.
        self.validators = self.default_validators + list(validators or [])

    def initialize(self, parent, model_field=None):
        """
        Called to set up a field prior to field_to_native or field_from_native.

        parent - The parent serializer.
        model_field - The model field this field corresponds to, if one exists.
        """
        self.parent = parent
        self.root = parent.root or parent
        self.context = self.root.context
        if model_field:
            self.model_field = model_field

    def validate(self, value):
        """
        Raise ValidationError when the field is required and ``value`` is empty.
        """
        if value in validators.EMPTY_VALUES and self.required:
            raise ValidationError(self.error_messages['required'])

    def run_validators(self, value):
        """
        Run every validator against a non-empty ``value``.

        All errors are collected and raised together as one ValidationError.
        If an error's ``code`` has an entry in ``self.error_messages``, that
        message is used instead, interpolated with the error's params.
        """
        if value in validators.EMPTY_VALUES:
            return

        errors = []
        for validator in self.validators:
            try:
                validator(value)
            except ValidationError as e:
                if hasattr(e, 'code') and e.code in self.error_messages:
                    message = self.error_messages[e.code]
                    params = getattr(e, 'params', None)
                    if params:
                        message = message % params
                    errors.append(message)
                else:
                    errors.extend(e.messages)

        if errors:
            raise ValidationError(errors)

    def field_from_native(self, data, field_name, into):
        """
        Given a dictionary and a field name, updates the dictionary `into`,
        with the field and its deserialized value.
        """
        if self.readonly:
            return

        try:
            native = data[field_name]
        except KeyError:
            return  # TODO Consider validation behaviour, 'required' opt etc...

        value = self.from_native(native)
        if self.source == '*':
            # A '*' source merges the deserialized mapping into `into`.
            if value:
                into.update(value)
        else:
            self.validate(value)
            self.run_validators(value)
            into[self.source or field_name] = value

    def from_native(self, value):
        """
        Reverts a simple representation back to the field's value.

        When bound to a model field, the value is coerced with the related
        model's target field's ``to_python``. If that fails for any reason,
        it falls back to the model field's own ``to_python``.
        """
        if hasattr(self, 'model_field'):
            try:
                rel = self.model_field.rel
                target = rel.to._meta.get_field(rel.field_name)
                return target.to_python(value)
            except Exception:
                # e.g. no relation (``rel`` is None) or the lookup/coercion
                # failed. Narrowed from a bare except so that SystemExit and
                # KeyboardInterrupt are no longer swallowed.
                return self.model_field.to_python(value)
        return value

    def field_to_native(self, obj, field_name):
        """
        Given an object and a field name, returns the value that should be
        serialized for that field.
        """
        if obj is None:
            return self.empty
        if self.source == '*':
            return self.to_native(obj)

        self.obj = obj  # Need to hang onto this in the case of model fields
        if hasattr(self, 'model_field'):
            return self.to_native(self.model_field._get_val_from_obj(obj))
        return self.to_native(getattr(obj, self.source or field_name))

    def to_native(self, value):
        """
        Converts the field's value into its simple representation.

        Zero-argument callables are called first. Protected types (per
        Django's ``is_protected_type``) are returned unchanged. Otherwise a
        bound model field formats the value, and anything else is coerced to
        unicode.
        """
        if is_simple_callable(value):
            value = value()
        if is_protected_type(value):
            return value
        elif hasattr(self, 'model_field'):
            # Formats ``self.obj`` (set by field_to_native), not ``value``.
            return self.model_field.value_to_string(self.obj)
        return smart_unicode(value)

    def attributes(self):
        """
        Returns a dictionary of attributes to be used when serializing to xml.
        """
        try:
            return {
                "type": self.model_field.get_internal_type()
            }
        except AttributeError:
            return {}
class RelatedField(Field):
    """
    A base class for model related fields or related managers.

    Subclass this and override `to_native` to define custom behaviour when
    serializing related objects.
    """

    def field_to_native(self, obj, field_name):
        """
        Serialize the related object, or each object of a related manager.

        Like Field.field_to_native, ``self.source`` (when set) names the
        attribute to read instead of ``field_name``.
        """
        value = getattr(obj, self.source or field_name)
        # Reverse-FK and many-to-many accessors yield managers: serialize
        # every member rather than the manager itself.
        if value.__class__.__name__ in ('RelatedManager', 'ManyRelatedManager'):
            return [self.to_native(item) for item in value.all()]
        return self.to_native(value)

    def attributes(self):
        """
        Return the relation class name and target model for xml output, or
        an empty dict when there is no bound model relation.
        """
        try:
            return {
                "rel": self.model_field.rel.__class__.__name__,
                "to": smart_unicode(self.model_field.rel.to._meta)
            }
        except AttributeError:
            return {}
class PrimaryKeyRelatedField(RelatedField):
    """
    Serializes a model related field or related manager to a pk value.
    """

    # Note that field_to_native reads the value via the model's
    # ``serializable_value``, as we want to get the raw database value
    # directly, since that won't involve another database lookup.
    #
    # An alternative implementation would simply be this...
    #
    # class PrimaryKeyRelatedField(RelatedField):
    #     def to_native(self, obj):
    #         return obj.pk

    def to_native(self, pk):
        """
        Simply returns the object's pk. You can subclass this method to
        provide different serialization behavior of the pk.
        (For example returning a URL based on the model's pk.)
        """
        return pk

    def field_to_native(self, obj, field_name):
        """
        Return the pk, or list of pks, of the relation named ``field_name``.
        """
        try:
            obj = obj.serializable_value(field_name)
        except AttributeError:
            # NOTE(review): presumably reached for reverse relations, where
            # serializable_value cannot resolve a raw column value -- confirm.
            field = obj._meta.get_field_by_name(field_name)[0]
            obj = getattr(obj, field_name)
            if obj.__class__.__name__ == 'RelatedManager':
                return [self.to_native(item.pk) for item in obj.all()]
            elif isinstance(field, RelatedObject):
                return self.to_native(obj.pk)
            raise
        if obj.__class__.__name__ == 'ManyRelatedManager':
            return [self.to_native(item.pk) for item in obj.all()]
        return self.to_native(obj)

    def field_from_native(self, data, field_name, into):
        """
        Deserialize the pk(s) in ``data[field_name]`` into ``into``.

        An iterable value is treated as a to-many relation and stored under
        ``field_name``. A single value is stored under ``field_name + '_id'``.
        Read-only fields and absent keys are skipped, the same way
        Field.field_from_native skips them, instead of writing None.
        """
        if self.readonly or field_name not in data:
            return

        value = data.get(field_name)
        if hasattr(value, '__iter__'):
            into[field_name] = [self.from_native(item) for item in value]
        else:
            into[field_name + '_id'] = self.from_native(value)
class NaturalKeyRelatedField(RelatedField):
    """
    Serializes a model related field or related manager to a natural key value.
    """
    is_natural_key = True  # XML renderer handles these differently

    def to_native(self, obj):
        """
        Return ``obj.natural_key()`` when available, otherwise ``obj`` as-is.
        """
        if hasattr(obj, 'natural_key'):
            return obj.natural_key()
        return obj

    def field_from_native(self, data, field_name, into):
        """
        Resolve the natural key in ``data[field_name]`` to a pk and store it
        under the model field's ``attname``.

        Read-only fields and absent keys are skipped. Previously an absent
        key ended in ``get_by_natural_key(*None)`` and a TypeError.
        """
        if self.readonly or field_name not in data:
            return
        value = data.get(field_name)
        into[self.model_field.attname] = self.from_native(value)

    def from_native(self, value):
        """
        Look up the related object by natural key and return its pk.
        None (an explicit null) is passed through unchanged.
        """
        if value is None:
            return None
        # TODO: Support 'using' : db = options.pop('using', DEFAULT_DB_ALIAS)
        manager = self.model_field.rel.to._default_manager
        manager = manager.db_manager(DEFAULT_DB_ALIAS)
        return manager.get_by_natural_key(*value).pk
class BooleanField(Field):
    """
    Deserializes booleans from True/False (or 1/0) and the strings
    't'/'True'/'1' and 'f'/'False'/'0'.
    """
    default_error_messages = {
        'invalid': _(u"'%s' value must be either True or False."),
    }

    def from_native(self, value):
        """
        Return a real ``bool`` for recognised input. Raise ValidationError
        for anything else.
        """
        if value in (True, False):
            # if value is 1 or 0 than it's equal to True or False, but we want
            # to return a true bool for semantic reasons.
            return bool(value)
        if value in ('t', 'True', '1'):
            return True
        if value in ('f', 'False', '0'):
            return False
        # Wrap in a 1-tuple so a tuple ``value`` is formatted, not unpacked
        # (which would raise TypeError instead of ValidationError).
        raise ValidationError(self.error_messages['invalid'] % (value,))
class CharField(Field):
    """
    Text field with optional minimum/maximum length validation.
    """

    def __init__(self, max_length=None, min_length=None, *args, **kwargs):
        self.max_length = max_length
        self.min_length = min_length
        super(CharField, self).__init__(*args, **kwargs)
        # Append the min-length check first, then the max-length check, each
        # only when its limit was given.
        length_checks = (
            (min_length, validators.MinLengthValidator),
            (max_length, validators.MaxLengthValidator),
        )
        for limit, validator_class in length_checks:
            if limit is not None:
                self.validators.append(validator_class(limit))

    def from_native(self, value):
        """
        Leave strings and None untouched. Coerce anything else to unicode.
        """
        if value is not None and not isinstance(value, basestring):
            value = smart_unicode(value)
        return value
class EmailField(CharField):
    """
    CharField that strips surrounding whitespace and validates the result as
    an e-mail address.
    """
    default_error_messages = {
        'invalid': _('Enter a valid e-mail address.'),
    }
    default_validators = [validators.validate_email]

    def from_native(self, value):
        """
        Coerce as CharField does, then strip whitespace. None is passed
        through instead of crashing on ``None.strip()``.
        """
        value = super(EmailField, self).from_native(value)
        if value is None:
            return value
        return value.strip()

    def __deepcopy__(self, memo):
        """
        Shallow-copy the field, giving the copy its own validators list so
        changes to one copy's validators do not leak into the other.
        """
        result = copy.copy(self)
        memo[id(self)] = result
        result.validators = self.validators[:]
        return result
class DateField(Field):
    """
    Deserializes dates from date/datetime objects or 'YYYY-MM-DD' strings.
    """
    default_error_messages = {
        'invalid': _(u"'%s' value has an invalid date format. It must be "
                     u"in YYYY-MM-DD format."),
        'invalid_date': _(u"'%s' value has the correct format (YYYY-MM-DD) "
                          u"but it is an invalid date."),
    }
    empty = None

    def from_native(self, value):
        """
        Return ``value`` as a ``datetime.date``. None passes through.

        Raises ValidationError for input in the wrong format, for input in
        the right format that is not a real date, and for non-string values.
        """
        if value is None:
            return value

        if isinstance(value, datetime.datetime):
            if timezone and settings.USE_TZ and timezone.is_aware(value):
                # Convert aware datetimes to the default time zone
                # before casting them to dates (#17742).
                default_timezone = timezone.get_default_timezone()
                value = timezone.make_naive(value, default_timezone)
            return value.date()
        if isinstance(value, datetime.date):
            return value

        try:
            parsed = parse_date(value)
            if parsed is not None:
                return parsed
        except ValueError:
            msg = self.error_messages['invalid_date'] % (value,)
            raise ValidationError(msg)
        except TypeError:
            # NOTE(review): presumably raised for non-string input such as an
            # int; it is reported as an invalid format below.
            pass

        # 1-tuple so that a tuple ``value`` is formatted rather than unpacked.
        msg = self.error_messages['invalid'] % (value,)
        raise ValidationError(msg)
class DateTimeField(Field):
    """
    Deserializes datetimes from datetime/date objects or
    'YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]' / 'YYYY-MM-DD' strings.
    """
    default_error_messages = {
        'invalid': _(u"'%s' value has an invalid format. It must be in "
                     u"YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ] format."),
        'invalid_date': _(u"'%s' value has the correct format "
                          u"(YYYY-MM-DD) but it is an invalid date."),
        'invalid_datetime': _(u"'%s' value has the correct format "
                              u"(YYYY-MM-DD HH:MM[:ss[.uuuuuu]][TZ]) "
                              u"but it is an invalid date/time."),
    }
    empty = None

    def from_native(self, value):
        """
        Return ``value`` as a ``datetime.datetime``. None passes through.

        A plain date becomes midnight of that day, made aware in the default
        time zone when time zone support is active. Raises ValidationError
        for unparseable, out-of-range or non-string input.
        """
        if value is None:
            return value

        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, datetime.date):
            value = datetime.datetime(value.year, value.month, value.day)
            # ``timezone`` is tested before USE_TZ, the same guard used for
            # aware-date handling elsewhere. It may presumably be None when
            # unavailable.
            if timezone and settings.USE_TZ:
                # For backwards compatibility, interpret naive datetimes in
                # local time. This won't work during DST change, but we can't
                # do much about it, so we let the exceptions percolate up the
                # call stack.
                warnings.warn(u"DateTimeField received a naive datetime (%s)"
                              u" while time zone support is active." % value,
                              RuntimeWarning)
                default_timezone = timezone.get_default_timezone()
                value = timezone.make_aware(value, default_timezone)
            return value

        try:
            parsed = parse_datetime(value)
            if parsed is not None:
                return parsed
        except ValueError:
            msg = self.error_messages['invalid_datetime'] % (value,)
            raise ValidationError(msg)
        except TypeError:
            # NOTE(review): presumably non-string input. Reported as invalid.
            pass

        try:
            parsed = parse_date(value)
            if parsed is not None:
                return datetime.datetime(parsed.year, parsed.month, parsed.day)
        except ValueError:
            msg = self.error_messages['invalid_date'] % (value,)
            raise ValidationError(msg)
        except TypeError:
            pass

        # 1-tuple so that a tuple ``value`` is formatted rather than unpacked.
        msg = self.error_messages['invalid'] % (value,)
        raise ValidationError(msg)
class IntegerField(Field):
    """
    Whole-number field with optional inclusive upper/lower bounds.
    """
    default_error_messages = {
        'invalid': _('Enter a whole number.'),
        'max_value': _('Ensure this value is less than or equal to %(limit_value)s.'),
        'min_value': _('Ensure this value is greater than or equal to %(limit_value)s.'),
    }

    def __init__(self, max_value=None, min_value=None, *args, **kwargs):
        self.max_value = max_value
        self.min_value = min_value
        super(IntegerField, self).__init__(*args, **kwargs)
        # Upper bound first, then lower bound, each only when supplied.
        bounds = (
            (max_value, validators.MaxValueValidator),
            (min_value, validators.MinValueValidator),
        )
        for bound, validator_class in bounds:
            if bound is not None:
                self.validators.append(validator_class(bound))

    def from_native(self, value):
        """
        Return None for empty input, else ``int(str(value))``. Raise
        ValidationError when that conversion fails.
        """
        if value in validators.EMPTY_VALUES:
            return None
        try:
            return int(str(value))
        except (ValueError, TypeError):
            raise ValidationError(self.error_messages['invalid'])
class FloatField(Field):
    """
    Deserializes values with ``float()``.
    """
    default_error_messages = {
        'invalid': _("'%s' value must be a float."),
    }

    def from_native(self, value):
        """
        Return ``float(value)``. None passes through. Raise ValidationError
        when the conversion fails.
        """
        if value is None:
            return value
        try:
            return float(value)
        except (TypeError, ValueError):
            # 1-tuple so a tuple ``value`` (e.g. from JSON arrays coerced
            # upstream) is formatted rather than unpacked into a TypeError.
            msg = self.error_messages['invalid'] % (value,)
            raise ValidationError(msg)
# field_mapping = {
# models.AutoField: IntegerField,
# models.BooleanField: BooleanField,
# models.CharField: CharField,
# models.DateTimeField: DateTimeField,
# models.DateField: DateField,
# models.BigIntegerField: IntegerField,
# models.IntegerField: IntegerField,
# models.PositiveIntegerField: IntegerField,
# models.FloatField: FloatField
# }
# def modelfield_to_serializerfield(field):
# return field_mapping.get(type(field), Field)