django-rest-framework/rest_framework/fields.py
2014-09-11 13:20:44 +01:00

815 lines
27 KiB
Python

from django.conf import settings
from django.core import validators
from django.core.exceptions import ValidationError
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _
from rest_framework import ISO_8601
from rest_framework.compat import smart_text
from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime
import datetime
import decimal
import inspect
import warnings
class empty:
    """
    Sentinel used to signal that no data was supplied for an input or
    output value.

    A dedicated marker is needed because `None` can itself be a perfectly
    valid input or output value.
    """
def is_simple_callable(obj):
    """
    Return True if `obj` is a function or method that can be called without
    passing any arguments.

    Only plain functions and methods are considered; any other object
    (including classes and other callables) yields False. A callable
    qualifies when every positional argument has a default value. For
    methods the first positional argument is assumed to be already bound.
    """
    is_function = inspect.isfunction(obj)
    is_method = inspect.ismethod(obj)
    if not (is_function or is_method):
        return False

    # `inspect.getargspec` was removed in Python 3.11 and rejects signatures
    # with keyword-only arguments or annotations, so prefer
    # `getfullargspec` where it exists and only fall back on old Pythons.
    argspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    spec = argspec(obj)
    # Both spec flavours expose `args` at index 0 and `defaults` at index 3.
    args, defaults = spec[0], spec[3]

    # Discount the bound `self`/`cls` argument of a method.
    len_args = len(args) if is_function else len(args) - 1
    len_defaults = len(defaults) if defaults else 0
    return len_args <= len_defaults
def get_attribute(instance, attrs):
    """
    Similar to Python's built in `getattr(instance, attr)`,
    but takes a list of nested attributes, instead of a single attribute.

    Each step first tries attribute access and, if that raises
    `AttributeError`, falls back to an item lookup (`instance[attr]`), so
    objects and dictionaries may be mixed freely along the path.

    Returns the value found at the end of the path, or `instance` itself
    when `attrs` is empty. Raises the original `AttributeError` when a step
    can be resolved neither as an attribute nor as an item.
    """
    for attr in attrs:
        try:
            instance = getattr(instance, attr)
        except AttributeError as exc:
            try:
                # Keep traversing after an item lookup; returning here would
                # silently ignore the remaining entries of `attrs`.
                instance = instance[attr]
            except (KeyError, TypeError):
                raise exc
    return instance
def set_value(dictionary, keys, value):
    """
    Assign `value` inside `dictionary` at the nested location described by
    the list `keys`, creating intermediate dictionaries where needed.

    An empty `keys` list merges `value` (itself a mapping) into
    `dictionary` instead of assigning it under a key.

    set_value({'a': 1}, [], {'b': 2}) -> {'a': 1, 'b': 2}
    set_value({'a': 1}, ['x'], 2) -> {'a': 1, 'x': 2}
    set_value({'a': 1}, ['x', 'y'], 2) -> {'a': 1, 'x': {'y': 2}}
    """
    if keys:
        node = dictionary
        # Walk (and create, where missing) every level except the last one.
        for key in keys[:-1]:
            if key not in node:
                node[key] = {}
            node = node[key]
        node[keys[-1]] = value
    else:
        dictionary.update(value)
class SkipField(Exception):
    """
    Raised by `Field.get_default()` when no input was provided for a field
    and the field has no default value.

    NOTE(review): presumably the calling serializer catches this and omits
    the field from the validated data -- confirm against the caller.
    """
# Assertion messages used by `Field.__init__` to reject keyword argument
# combinations that contradict each other.
NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'

# Template for the AssertionError raised by `Field.fail()` when it is asked
# for an error key that the field's `error_messages` does not define.
MISSING_ERROR_MESSAGE = (
    'ValidationError raised by `{class_name}`, '
    'but error key `{key}` does not exist in the `error_messages` dictionary.'
)
class Field(object):
    """
    Base class for all serializer fields.

    Stores the keyword arguments shared by every field, knows how to bind
    itself to a parent, and provides the conversion hooks used by
    subclasses: `to_native()` for *incoming* primitive data and
    `to_primative()` for *outgoing* native values.
    """
    # Class-wide counter; every instance records the value current at the
    # time it was created.
    _creation_counter = 0

    default_error_messages = {
        'required': _('This field is required.')
    }
    default_validators = []

    def __init__(self, read_only=False, write_only=False,
                 required=None, default=empty, initial=None, source=None,
                 label=None, help_text=None, style=None,
                 error_messages=None, validators=None):
        """
        Store the field options and build the per-instance validator list
        and error message dictionary.

        `required` defaults to True unless a `default` is given or the field
        is `read_only`. Contradictory option combinations trigger an
        `AssertionError`. `error_messages` overrides the messages collected
        from `default_error_messages` on the class hierarchy, and
        `validators` replaces `default_validators` when non-empty.
        """
        self._creation_counter = Field._creation_counter
        Field._creation_counter += 1

        # If `required` is unset, then use `True` unless a default is provided.
        if required is None:
            required = default is empty and not read_only

        # Some combinations of keyword arguments do not make sense.
        assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
        assert not (read_only and required), NOT_READ_ONLY_REQUIRED
        assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
        assert not (required and default is not empty), NOT_REQUIRED_DEFAULT

        self.read_only = read_only
        self.write_only = write_only
        self.required = required
        self.default = default
        self.source = source
        self.initial = initial
        self.label = label
        self.help_text = help_text
        self.style = {} if style is None else style

        # Always keep a private copy: subclasses append to `self.validators`,
        # which must never mutate the caller's list or the class-level
        # `default_validators`.
        self.validators = list(validators or self.default_validators)

        # Collect default error message from self and parent classes
        messages = {}
        for cls in reversed(self.__class__.__mro__):
            messages.update(getattr(cls, 'default_error_messages', {}))
        messages.update(error_messages or {})
        self.error_messages = messages

    def bind(self, field_name, parent, root):
        """
        Setup the context for the field instance.

        Records the field name, parent and root, copies `parent.context`,
        and fills in `label`, `source` and `source_attrs` defaults.
        """
        self.field_name = field_name
        self.parent = parent
        self.root = root
        self.context = parent.context

        # `self.label` should default to being based on the field name.
        if self.label is None:
            self.label = self.field_name.replace('_', ' ').capitalize()

        # self.source should default to being the same as the field name.
        if self.source is None:
            self.source = field_name

        # self.source_attrs is a list of attributes that need to be looked up
        # when serializing the instance, or populating the validated data.
        if self.source == '*':
            self.source_attrs = []
        else:
            self.source_attrs = self.source.split('.')

    def get_initial(self):
        """
        Return a value to use when the field is being returned as a primitive
        value, without any object instance.
        """
        return self.initial

    def get_value(self, dictionary):
        """
        Given the *incoming* primitive data, return the value for this field
        that should be validated and transformed to a native value.

        Returns `empty` when the field name is not present in `dictionary`.
        """
        return dictionary.get(self.field_name, empty)

    def get_attribute(self, instance):
        """
        Given the *outgoing* object instance, return the value for this field
        that should be returned as a primitive value.
        """
        return get_attribute(instance, self.source_attrs)

    def get_default(self):
        """
        Return the default value to use when validating data if no input
        is provided for this field.

        If a default has not been set for this field then `SkipField` is
        raised, indicating that no value should be set in the validated
        data for this field.
        """
        if self.default is empty:
            raise SkipField()
        return self.default

    def validate_value(self, data=empty):
        """
        Validate a simple representation and return the internal value.

        The provided data may be `empty` if no representation was included.
        In that case a required field fails with the `required` error,
        otherwise the default is returned (or `SkipField` is raised when
        there is none). Any other data is converted with `to_native()` and
        then passed through the validators.
        """
        if data is empty:
            if self.required:
                self.fail('required')
            return self.get_default()

        value = self.to_native(data)
        self.run_validators(value)
        return value

    def run_validators(self, value):
        """
        Run every validator against `value`, collecting all error messages
        into a single `ValidationError`.

        Values listed in `validators.EMPTY_VALUES` are not validated.
        """
        if value in validators.EMPTY_VALUES:
            return

        errors = []
        for validator in self.validators:
            try:
                validator(value)
            except ValidationError as exc:
                errors.extend(exc.messages)
        if errors:
            raise ValidationError(errors)

    def to_native(self, data):
        """
        Transform the *incoming* primitive data into a native value.
        """
        return data

    def to_primative(self, value):
        """
        Transform the *outgoing* native value into primitive data.
        """
        return value

    def fail(self, key, **kwargs):
        """
        A helper method that simply raises a validation error.

        The message is looked up under `key` in `self.error_messages` and
        formatted with `kwargs` using `str.format`. An unknown `key` is a
        programming error and raises `AssertionError` instead.
        """
        try:
            msg = self.error_messages[key]
        except KeyError:
            class_name = self.__class__.__name__
            msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
            raise AssertionError(msg)
        raise ValidationError(msg.format(**kwargs))

    def __new__(cls, *args, **kwargs):
        """
        When a field is instantiated, we store the arguments that were used,
        so that we can present a helpful representation of the object.
        """
        instance = super(Field, cls).__new__(cls)
        instance._args = args
        instance._kwargs = kwargs
        return instance

    def __repr__(self):
        return representation.field_repr(self)
# Boolean types...
class BooleanField(Field):
    """
    A field that maps a fixed set of truthy/falsy representations onto
    `True` and `False`.
    """
    default_error_messages = {
        'invalid': _('`{input}` is not a valid boolean.')
    }
    TRUE_VALUES = set(('t', 'T', 'true', 'True', 'TRUE', '1', 1, True))
    FALSE_VALUES = set(('f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False))

    def get_value(self, dictionary):
        """
        Return the raw input for this field, treating a missing key in HTML
        form input as `False` rather than `empty`.
        """
        if html.is_html_input(dictionary):
            # HTML forms do not send a `False` value on an empty checkbox,
            # so we override the default empty value to be False.
            return dictionary.get(self.field_name, False)
        return dictionary.get(self.field_name, empty)

    def to_native(self, data):
        """
        Convert incoming data to a boolean, failing with the `invalid` error
        for anything that is not a recognised representation.
        """
        try:
            if data in self.TRUE_VALUES:
                return True
            elif data in self.FALSE_VALUES:
                return False
        except TypeError:
            # Unhashable input (e.g. a list or dict) cannot be looked up in
            # a set; report it as invalid instead of crashing.
            pass
        self.fail('invalid', input=data)

    def to_primative(self, value):
        """
        Convert an outgoing value to a boolean; `None` is passed through and
        unrecognised values fall back to `bool(value)`.
        """
        if value is None:
            return None
        try:
            if value in self.TRUE_VALUES:
                return True
            elif value in self.FALSE_VALUES:
                return False
        except TypeError:
            # Unhashable values fall through to plain truthiness.
            pass
        return bool(value)
# String types...
class CharField(Field):
    """
    A text field.

    Accepts the extra keyword arguments `allow_blank` (default False),
    `max_length` and `min_length` (both default None, meaning no limit).
    """
    default_error_messages = {
        'blank': _('This field may not be blank.')
    }

    def __init__(self, **kwargs):
        self.allow_blank = kwargs.pop('allow_blank', False)
        self.max_length = kwargs.pop('max_length', None)
        self.min_length = kwargs.pop('min_length', None)
        super(CharField, self).__init__(**kwargs)

        # Enforce the length limits; previously they were stored but never
        # checked. A new list is built rather than appending in place so a
        # validators list supplied by the caller is never mutated.
        length_validators = []
        if self.max_length is not None:
            length_validators.append(
                validators.MaxLengthValidator(self.max_length)
            )
        if self.min_length is not None:
            length_validators.append(
                validators.MinLengthValidator(self.min_length)
            )
        if length_validators:
            self.validators = list(self.validators) + length_validators

    def to_native(self, data):
        """
        Convert incoming data to a string, failing with the `blank` error
        for an empty string unless `allow_blank` is set.
        """
        if data == '' and not self.allow_blank:
            self.fail('blank')
        return str(data)

    def to_primative(self, value):
        """
        Convert an outgoing value to a string; `None` is passed through.
        """
        if value is None:
            return None
        return str(value)
class EmailField(CharField):
    """
    A text field validated with `validators.validate_email`; surrounding
    whitespace is stripped in both directions.
    """
    default_error_messages = {
        'invalid': _('Enter a valid email address.')
    }
    default_validators = [validators.validate_email]

    def to_native(self, data):
        """Return the incoming text with surrounding whitespace removed."""
        text = super(EmailField, self).to_native(data)
        return None if text is None else text.strip()

    def to_primative(self, value):
        """Return the outgoing text with surrounding whitespace removed."""
        text = super(EmailField, self).to_primative(value)
        return None if text is None else text.strip()
class RegexField(CharField):
    """
    A text field whose value must match `regex`.

    A `validators.RegexValidator(regex)` is placed in front of any
    validators passed by the caller.
    """
    def __init__(self, regex, **kwargs):
        # Accept `validators=None` or any iterable (e.g. a tuple); plain
        # list concatenation would raise `TypeError` for those.
        extra_validators = kwargs.get('validators') or []
        kwargs['validators'] = (
            [validators.RegexValidator(regex)] + list(extra_validators)
        )
        super(RegexField, self).__init__(**kwargs)
class SlugField(CharField):
    """A text field validated with `validators.validate_slug`."""

    default_validators = [validators.validate_slug]
    default_error_messages = {
        'invalid': _(
            "Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens."
        ),
    }
class URLField(CharField):
    """A text field validated with a `validators.URLValidator` instance."""

    default_validators = [validators.URLValidator()]
    default_error_messages = {
        'invalid': _('Enter a valid URL.'),
    }
# Number types...
class IntegerField(Field):
    """
    An integer field with optional `max_value` / `min_value` bounds.
    """
    default_error_messages = {
        'invalid': _('A valid integer is required.')
    }

    def __init__(self, **kwargs):
        upper_bound = kwargs.pop('max_value', None)
        lower_bound = kwargs.pop('min_value', None)
        super(IntegerField, self).__init__(**kwargs)
        # Bounds are enforced through the regular validator machinery.
        if upper_bound is not None:
            self.validators.append(validators.MaxValueValidator(upper_bound))
        if lower_bound is not None:
            self.validators.append(validators.MinValueValidator(lower_bound))

    def to_native(self, data):
        """
        Convert incoming data to an `int` via its string form, failing with
        the `invalid` error when that is not possible.
        """
        try:
            return int(str(data))
        except (ValueError, TypeError):
            self.fail('invalid')

    def to_primative(self, value):
        """Convert an outgoing value to an `int`; `None` passes through."""
        return None if value is None else int(value)
class FloatField(Field):
    """
    A floating point field with optional `max_value` / `min_value` bounds.
    """
    default_error_messages = {
        # `Field.fail()` formats messages with `str.format`, so the
        # placeholder must be `{value}` rather than `%s`.
        'invalid': _("'{value}' value must be a float."),
    }

    def __init__(self, **kwargs):
        max_value = kwargs.pop('max_value', None)
        min_value = kwargs.pop('min_value', None)
        super(FloatField, self).__init__(**kwargs)
        if max_value is not None:
            self.validators.append(validators.MaxValueValidator(max_value))
        if min_value is not None:
            self.validators.append(validators.MinValueValidator(min_value))

    def to_primative(self, value):
        """
        Convert an outgoing value to a `float`; `None` passes through and
        unconvertible values fail with the `invalid` error.
        """
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            self.fail('invalid', value=value)

    def to_native(self, value):
        """
        Convert incoming data to a `float`; `None` passes through.

        Unconvertible input fails with the `invalid` validation error
        instead of letting `ValueError`/`TypeError` escape.
        """
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            self.fail('invalid', value=value)
class DecimalField(Field):
    """
    A decimal number field.

    `max_value` / `min_value` bound the value, `max_digits` limits the total
    number of digits and `decimal_places` limits the digits after the
    decimal point. All four default to None, meaning no limit.
    """
    default_error_messages = {
        'invalid': _('Enter a number.'),
        'max_value': _('Ensure this value is less than or equal to {max_value}.'),
        'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
        'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
        'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
        'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
    }

    def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs):
        self.max_value, self.min_value = max_value, min_value
        self.max_digits = max_digits
        # The validation code reads `decimal_places`; `max_decimal_places`
        # is kept as well because it was the attribute originally set.
        self.decimal_places = self.max_decimal_places = decimal_places
        super(DecimalField, self).__init__(**kwargs)
        if max_value is not None:
            self.validators.append(validators.MaxValueValidator(max_value))
        if min_value is not None:
            self.validators.append(validators.MinValueValidator(min_value))

    def to_native(self, data):
        """
        Transform the *incoming* primitive data into a `Decimal`.

        `Field.validate_value()` calls `to_native()`, so this hook routes
        incoming data through the parsing in `from_native()`.
        """
        return self.from_native(data)

    def from_native(self, value):
        """
        Validates that the input is a decimal number. Returns a Decimal
        instance. Returns None for empty values. Ensures that there are no more
        than max_digits in the number, and no more than decimal_places digits
        after the decimal point.
        """
        if value in validators.EMPTY_VALUES:
            return None
        value = smart_text(value).strip()
        try:
            value = decimal.Decimal(value)
        except decimal.DecimalException:
            self.fail('invalid')

        # Reject NaN (quiet or signaling) and +/- infinity. The predicate
        # methods are used because comparing a signaling NaN raises.
        if value.is_nan() or value.is_infinite():
            self.fail('invalid')

        sign, digittuple, exponent = value.as_tuple()
        if exponent >= 0:
            # An integer value: a positive exponent adds that many trailing
            # zeros before the decimal point (e.g. 1E+3 == 1000).
            digits = len(digittuple)
            if digittuple != (0,):
                digits += exponent
            decimals = 0
        else:
            decimals = abs(exponent)
            # digittuple doesn't include any leading zeros.
            digits = len(digittuple)
            if decimals > digits:
                # We have leading zeros up to or past the decimal point. Count
                # everything past the decimal point as a digit. We do not
                # count 0 before the decimal point as a digit since that
                # would mean we would not allow max_digits = decimal_places.
                digits = decimals
        whole_digits = digits - decimals

        if self.max_digits is not None and digits > self.max_digits:
            self.fail('max_digits', max_digits=self.max_digits)
        if self.decimal_places is not None and decimals > self.decimal_places:
            self.fail('max_decimal_places', max_decimal_places=self.decimal_places)
        if (self.max_digits is not None and self.decimal_places is not None and
                whole_digits > (self.max_digits - self.decimal_places)):
            self.fail('max_whole_digits', max_whole_digits=self.max_digits - self.decimal_places)
        return value
# Date & time fields...
class DateField(Field):
    """
    A date field.

    `input_formats` lists the formats accepted for incoming strings and
    `format` controls the outgoing representation; both default to the
    corresponding `api_settings` values.
    """
    default_error_messages = {
        # Interpolated with `%` in `from_native()`, not via `fail()`.
        'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
    }
    input_formats = api_settings.DATE_INPUT_FORMATS
    format = api_settings.DATE_FORMAT

    def __init__(self, input_formats=None, format=None, *args, **kwargs):
        self.input_formats = input_formats if input_formats is not None else self.input_formats
        self.format = format if format is not None else self.format
        super(DateField, self).__init__(*args, **kwargs)

    def to_native(self, data):
        """
        Transform the *incoming* primitive data into a `datetime.date`.

        `Field.validate_value()` calls `to_native()`, so this hook routes
        incoming data through the parsing in `from_native()`.
        """
        return self.from_native(data)

    def from_native(self, value):
        """
        Parse `value` into a `datetime.date`.

        Empty values yield None, `date` instances are returned unchanged
        and `datetime` instances are reduced to their date. Strings are
        tried against each entry of `input_formats`; a `ValidationError` is
        raised when none of them matches.
        """
        if value in validators.EMPTY_VALUES:
            return None

        if isinstance(value, datetime.datetime):
            if timezone and settings.USE_TZ and timezone.is_aware(value):
                # Convert aware datetimes to the default time zone
                # before casting them to dates (#17742).
                default_timezone = timezone.get_default_timezone()
                value = timezone.make_naive(value, default_timezone)
            return value.date()
        if isinstance(value, datetime.date):
            return value

        # `fmt` rather than `format`, to avoid shadowing the builtin.
        for fmt in self.input_formats:
            if fmt.lower() == ISO_8601:
                try:
                    parsed = parse_date(value)
                except (ValueError, TypeError):
                    pass
                else:
                    if parsed is not None:
                        return parsed
            else:
                try:
                    parsed = datetime.datetime.strptime(value, fmt)
                except (ValueError, TypeError):
                    pass
                else:
                    return parsed.date()

        humanized_format = humanize_datetime.date_formats(self.input_formats)
        msg = self.error_messages['invalid'] % humanized_format
        raise ValidationError(msg)

    def to_primative(self, value):
        """
        Format an outgoing date using `self.format`.

        The value is returned unchanged when it is None or when
        `self.format` is None.
        """
        if value is None or self.format is None:
            return value

        if isinstance(value, datetime.datetime):
            value = value.date()

        if self.format.lower() == ISO_8601:
            return value.isoformat()
        return value.strftime(self.format)
class DateTimeField(Field):
    """
    A date-and-time field.

    `input_formats` lists the formats accepted for incoming strings and
    `format` controls the outgoing representation; both default to the
    corresponding `api_settings` values.
    """
    default_error_messages = {
        # Interpolated with `%` in `from_native()`, not via `fail()`.
        'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
    }
    input_formats = api_settings.DATETIME_INPUT_FORMATS
    format = api_settings.DATETIME_FORMAT

    def __init__(self, input_formats=None, format=None, *args, **kwargs):
        self.input_formats = input_formats if input_formats is not None else self.input_formats
        self.format = format if format is not None else self.format
        super(DateTimeField, self).__init__(*args, **kwargs)

    def to_native(self, data):
        """
        Transform the *incoming* primitive data into a `datetime.datetime`.

        `Field.validate_value()` calls `to_native()`, so this hook routes
        incoming data through the parsing in `from_native()`.
        """
        return self.from_native(data)

    def from_native(self, value):
        """
        Parse `value` into a `datetime.datetime`.

        Empty values yield None and `datetime` instances are returned
        unchanged. A plain `date` becomes midnight of that day (made aware
        in the default time zone, with a `RuntimeWarning`, when
        `settings.USE_TZ` is on). Strings are tried against each entry of
        `input_formats`; a `ValidationError` is raised when none matches.
        """
        if value in validators.EMPTY_VALUES:
            return None

        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, datetime.date):
            value = datetime.datetime(value.year, value.month, value.day)
            if settings.USE_TZ:
                # For backwards compatibility, interpret naive datetimes in
                # local time. This won't work during DST change, but we can't
                # do much about it, so we let the exceptions percolate up the
                # call stack.
                warnings.warn("DateTimeField received a naive datetime (%s)"
                              " while time zone support is active." % value,
                              RuntimeWarning)
                default_timezone = timezone.get_default_timezone()
                value = timezone.make_aware(value, default_timezone)
            return value

        # `fmt` rather than `format`, to avoid shadowing the builtin.
        for fmt in self.input_formats:
            if fmt.lower() == ISO_8601:
                try:
                    parsed = parse_datetime(value)
                except (ValueError, TypeError):
                    pass
                else:
                    if parsed is not None:
                        return parsed
            else:
                try:
                    parsed = datetime.datetime.strptime(value, fmt)
                except (ValueError, TypeError):
                    pass
                else:
                    return parsed

        humanized_format = humanize_datetime.datetime_formats(self.input_formats)
        msg = self.error_messages['invalid'] % humanized_format
        raise ValidationError(msg)

    def to_primative(self, value):
        """
        Format an outgoing datetime using `self.format`.

        The value is returned unchanged when it is None or when
        `self.format` is None. In ISO 8601 mode a trailing `+00:00` offset
        is rewritten as `Z`.
        """
        if value is None or self.format is None:
            return value

        if self.format.lower() == ISO_8601:
            ret = value.isoformat()
            if ret.endswith('+00:00'):
                ret = ret[:-6] + 'Z'
            return ret
        return value.strftime(self.format)
class TimeField(Field):
    """
    A time-of-day field.

    `input_formats` lists the formats accepted for incoming strings and
    `format` controls the outgoing representation; both default to the
    corresponding `api_settings` values.
    """
    default_error_messages = {
        # Interpolated with `%` in `from_native()`, not via `fail()`.
        'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
    }
    input_formats = api_settings.TIME_INPUT_FORMATS
    format = api_settings.TIME_FORMAT

    def __init__(self, input_formats=None, format=None, *args, **kwargs):
        self.input_formats = input_formats if input_formats is not None else self.input_formats
        self.format = format if format is not None else self.format
        super(TimeField, self).__init__(*args, **kwargs)

    def to_native(self, data):
        """
        Transform the *incoming* primitive data into a `datetime.time`.

        `Field.validate_value()` calls `to_native()`, so this hook routes
        incoming data through the parsing in `from_native()`.
        """
        return self.from_native(data)

    def from_native(self, value):
        """
        Parse `value` into a `datetime.time`.

        Empty values yield None and `time` instances are returned
        unchanged. Strings are tried against each entry of `input_formats`;
        a `ValidationError` is raised when none of them matches.
        """
        if value in validators.EMPTY_VALUES:
            return None

        if isinstance(value, datetime.time):
            return value

        # `fmt` rather than `format`, to avoid shadowing the builtin.
        for fmt in self.input_formats:
            if fmt.lower() == ISO_8601:
                try:
                    parsed = parse_time(value)
                except (ValueError, TypeError):
                    pass
                else:
                    if parsed is not None:
                        return parsed
            else:
                try:
                    parsed = datetime.datetime.strptime(value, fmt)
                except (ValueError, TypeError):
                    pass
                else:
                    return parsed.time()

        humanized_format = humanize_datetime.time_formats(self.input_formats)
        msg = self.error_messages['invalid'] % humanized_format
        raise ValidationError(msg)

    def to_primative(self, value):
        """
        Format an outgoing time using `self.format`.

        The value is returned unchanged when it is None or when
        `self.format` is None.
        """
        if value is None or self.format is None:
            return value

        if isinstance(value, datetime.datetime):
            value = value.time()

        if self.format.lower() == ISO_8601:
            return value.isoformat()
        return value.strftime(self.format)
# Choice types...
class ChoiceField(Field):
    """
    A field restricted to a fixed set of choices.

    `choices` may be a flat sequence of values or a sequence of
    `(value, display_value)` pairs.
    """
    default_error_messages = {
        'invalid_choice': _('`{input}` is not a valid choice.')
    }

    def __init__(self, choices, **kwargs):
        # Two spellings are supported:
        #   choices = [1, 2, 3]
        #   choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
        is_pair = [
            isinstance(entry, (list, tuple)) and len(entry) == 2
            for entry in choices
        ]
        if all(is_pair):
            mapping = dict(choices)
        else:
            mapping = dict([(entry, entry) for entry in choices])
        self.choices = mapping

        # Index the choices by their string form, so that input such as
        # '1' resolves to the integer choice 1 and the original datatype is
        # what ends up in the validated data.
        self.choice_strings_to_values = dict([
            (str(key), key) for key in mapping
        ])

        super(ChoiceField, self).__init__(**kwargs)

    def to_native(self, data):
        """
        Return the choice value whose string form equals `str(data)`,
        failing with the `invalid_choice` error when there is none.
        """
        lookup = self.choice_strings_to_values
        try:
            return lookup[str(data)]
        except KeyError:
            self.fail('invalid_choice', input=data)

    def to_primative(self, value):
        """Outgoing values are passed through untouched."""
        return value
class MultipleChoiceField(ChoiceField):
    """
    A choice field that accepts several choices at once and yields them as
    a set.
    """
    default_error_messages = {
        'invalid_choice': _('`{input}` is not a valid choice.'),
        'not_a_list': _('Expected a list of items but got type `{input_type}`')
    }

    def to_native(self, data):
        """
        Validate every item of `data` as a single choice and return the
        resulting values as a set.

        Non-iterable input fails with the `not_a_list` error. Strings and
        bytes are rejected in the same way: on Python 3 they define
        `__iter__` and would otherwise be validated one character at a time.
        """
        if isinstance(data, (type(''), bytes)) or not hasattr(data, '__iter__'):
            self.fail('not_a_list', input_type=type(data).__name__)
        return set([
            super(MultipleChoiceField, self).to_native(item)
            for item in data
        ])

    def to_primative(self, value):
        """Outgoing values are passed through untouched."""
        return value
# File types...
class FileField(Field):
    """Placeholder: adds nothing to `Field` yet."""
    pass  # TODO
class ImageField(Field):
    """Placeholder: adds nothing to `Field` yet."""
    pass  # TODO
# Advanced field types...
class ReadOnlyField(Field):
    """
    A read-only field that simply returns the field value.

    If the field is a method with no parameters, the method will be called
    and its return value used as the representation.

    For example, the following would call `get_expiry_date()` on the object:

    class ExampleSerializer(Serializer):
        expiry_date = ReadOnlyField(source='get_expiry_date')
    """

    def __init__(self, **kwargs):
        kwargs['read_only'] = True
        super(ReadOnlyField, self).__init__(**kwargs)

    def to_native(self, data):
        """Incoming data is never accepted by a read-only field."""
        # `NotImplemented` is a comparison sentinel, not an exception;
        # calling it raises `TypeError`. `NotImplementedError` is the
        # intended exception.
        raise NotImplementedError('.to_native() not supported.')

    def to_primative(self, value):
        """
        Return `value`, calling it first when it is a function or method
        that takes no arguments.
        """
        if is_simple_callable(value):
            return value()
        return value
class MethodField(Field):
    """
    A read-only field that gets its representation from calling a method on
    the parent serializer class. The method called will be of the form
    "get_{field_name}", and should take a single argument, which is the
    object being serialized.

    For example:

    class ExampleSerializer(Serializer):
        extra_info = MethodField()

        def get_extra_info(self, obj):
            return ...  # Calculate some data to return.
    """

    def __init__(self, **kwargs):
        # A source of '*' gives the field an empty attribute path
        # (see `Field.bind`), so the whole object is handed over.
        kwargs['source'] = '*'
        kwargs['read_only'] = True
        super(MethodField, self).__init__(**kwargs)

    def to_native(self, data):
        """Incoming data is never accepted by a method field."""
        # `NotImplemented` is a comparison sentinel, not an exception;
        # calling it raises `TypeError`. `NotImplementedError` is the
        # intended exception.
        raise NotImplementedError('.to_native() not supported.')

    def to_primative(self, value):
        """
        Call `get_<field_name>(value)` on the parent and return its result.
        """
        attr = 'get_{field_name}'.format(field_name=self.field_name)
        method = getattr(self.parent, attr)
        return method(value)
class ModelField(Field):
    """
    Generic serializer field wrapping an arbitrary model field.

    `ModelSerializer` falls back to this for custom model fields that have
    no dedicated serializer field to map to.
    """
    def __init__(self, model_field, **kwargs):
        kwargs['source'] = '*'
        self.model_field = model_field
        super(ModelField, self).__init__(**kwargs)

    def to_native(self, data):
        """
        Convert incoming data with the model field's `to_python()`, or with
        that of the related target field when the model field has a `rel`.
        """
        relation = getattr(self.model_field, 'rel', None)
        if relation is None:
            return self.model_field.to_python(data)
        target_field = relation.to._meta.get_field(relation.field_name)
        return target_field.to_python(data)

    def to_primative(self, obj):
        """
        Return the model field's value for `obj`: unchanged when it is a
        protected type, otherwise as produced by `value_to_string()`.
        """
        raw_value = self.model_field._get_val_from_obj(obj)
        if not is_protected_type(raw_value):
            return self.model_field.value_to_string(obj)
        return raw_value