mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-26 03:23:59 +03:00
Fleshing out serializer fields
This commit is contained in:
parent
21980b800d
commit
b1c07670ca
|
@ -1,8 +1,18 @@
|
|||
from django.conf import settings
|
||||
from django.core import validators
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils import timezone
|
||||
from django.utils.dateparse import parse_date, parse_datetime, parse_time
|
||||
from django.utils.encoding import is_protected_type
|
||||
from rest_framework.utils import html
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
from rest_framework import ISO_8601
|
||||
from rest_framework.compat import smart_text
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import html, representation, humanize_datetime
|
||||
import datetime
|
||||
import decimal
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
|
||||
class empty:
|
||||
|
@ -71,22 +81,22 @@ class SkipField(Exception):
|
|||
pass
|
||||
|
||||
|
||||
NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
|
||||
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
|
||||
NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
|
||||
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
|
||||
MISSING_ERROR_MESSAGE = (
|
||||
'ValidationError raised by `{class_name}`, but error key `{key}` does '
|
||||
'not exist in the `error_messages` dictionary.'
|
||||
)
|
||||
|
||||
|
||||
class Field(object):
|
||||
_creation_counter = 0
|
||||
|
||||
MESSAGES = {
|
||||
'required': 'This field is required.'
|
||||
default_error_messages = {
|
||||
'required': _('This field is required.')
|
||||
}
|
||||
|
||||
_NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
|
||||
_NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
|
||||
_NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
|
||||
_NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
|
||||
_MISSING_ERROR_MESSAGE = (
|
||||
'ValidationError raised by `{class_name}`, but error key `{key}` does '
|
||||
'not exist in the `MESSAGES` dictionary.'
|
||||
)
|
||||
|
||||
default_validators = []
|
||||
|
||||
def __init__(self, read_only=False, write_only=False,
|
||||
|
@ -100,10 +110,10 @@ class Field(object):
|
|||
required = default is empty and not read_only
|
||||
|
||||
# Some combinations of keyword arguments do not make sense.
|
||||
assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY
|
||||
assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED
|
||||
assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT
|
||||
assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT
|
||||
assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
|
||||
assert not (read_only and required), NOT_READ_ONLY_REQUIRED
|
||||
assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
|
||||
assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
|
||||
|
||||
self.read_only = read_only
|
||||
self.write_only = write_only
|
||||
|
@ -113,7 +123,14 @@ class Field(object):
|
|||
self.initial = initial
|
||||
self.label = label
|
||||
self.style = {} if style is None else style
|
||||
self.validators = self.default_validators + validators
|
||||
self.validators = validators or self.default_validators[:]
|
||||
|
||||
# Collect default error message from self and parent classes
|
||||
messages = {}
|
||||
for cls in reversed(self.__class__.__mro__):
|
||||
messages.update(getattr(cls, 'default_error_messages', {}))
|
||||
messages.update(error_messages or {})
|
||||
self.error_messages = messages
|
||||
|
||||
def bind(self, field_name, parent, root):
|
||||
"""
|
||||
|
@ -186,12 +203,14 @@ class Field(object):
|
|||
self.fail('required')
|
||||
return self.get_default()
|
||||
|
||||
self.run_validators(data)
|
||||
return self.to_native(data)
|
||||
value = self.to_native(data)
|
||||
self.run_validators(value)
|
||||
return value
|
||||
|
||||
def run_validators(self, value):
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return
|
||||
|
||||
errors = []
|
||||
for validator in self.validators:
|
||||
try:
|
||||
|
@ -218,33 +237,32 @@ class Field(object):
|
|||
A helper method that simply raises a validation error.
|
||||
"""
|
||||
try:
|
||||
raise ValidationError(self.MESSAGES[key].format(**kwargs))
|
||||
msg = self.error_messages[key]
|
||||
except KeyError:
|
||||
class_name = self.__class__.__name__
|
||||
msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
|
||||
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
|
||||
raise AssertionError(msg)
|
||||
raise ValidationError(msg.format(**kwargs))
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""
|
||||
When a field is instantiated, we store the arguments that were used,
|
||||
so that we can present a helpful representation of the object.
|
||||
"""
|
||||
instance = super(Field, cls).__new__(cls)
|
||||
instance._args = args
|
||||
instance._kwargs = kwargs
|
||||
return instance
|
||||
|
||||
def __repr__(self):
|
||||
arg_string = ', '.join([repr(val) for val in self._args])
|
||||
kwarg_string = ', '.join([
|
||||
'%s=%s' % (key, repr(val)) for key, val in self._kwargs.items()
|
||||
])
|
||||
if arg_string and kwarg_string:
|
||||
arg_string += ', '
|
||||
class_name = self.__class__.__name__
|
||||
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
|
||||
return representation.field_repr(self)
|
||||
|
||||
|
||||
# Boolean types...
|
||||
|
||||
class BooleanField(Field):
|
||||
MESSAGES = {
|
||||
'required': 'This field is required.',
|
||||
'invalid_value': '`{input}` is not a valid boolean.'
|
||||
default_error_messages = {
|
||||
'invalid': _('`{input}` is not a valid boolean.')
|
||||
}
|
||||
TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
|
||||
FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
|
||||
|
@ -261,13 +279,23 @@ class BooleanField(Field):
|
|||
return True
|
||||
elif data in self.FALSE_VALUES:
|
||||
return False
|
||||
self.fail('invalid_value', input=data)
|
||||
self.fail('invalid', input=data)
|
||||
|
||||
def to_primative(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
if value in self.TRUE_VALUES:
|
||||
return True
|
||||
elif value in self.FALSE_VALUES:
|
||||
return False
|
||||
return bool(value)
|
||||
|
||||
|
||||
# String types...
|
||||
|
||||
class CharField(Field):
|
||||
MESSAGES = {
|
||||
'required': 'This field is required.',
|
||||
'blank': 'This field may not be blank.'
|
||||
default_error_messages = {
|
||||
'blank': _('This field may not be blank.')
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
@ -281,19 +309,364 @@ class CharField(Field):
|
|||
self.fail('blank')
|
||||
return str(data)
|
||||
|
||||
def to_primative(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
return str(value)
|
||||
|
||||
class ChoiceField(Field):
|
||||
MESSAGES = {
|
||||
'required': 'This field is required.',
|
||||
'invalid_choice': '`{input}` is not a valid choice.'
|
||||
|
||||
class EmailField(CharField):
|
||||
default_error_messages = {
|
||||
'invalid': _('Enter a valid email address.')
|
||||
}
|
||||
default_validators = [validators.validate_email]
|
||||
|
||||
def to_native(self, data):
|
||||
ret = super(EmailField, self).to_native(data)
|
||||
if ret is None:
|
||||
return None
|
||||
return ret.strip()
|
||||
|
||||
def to_primative(self, value):
|
||||
ret = super(EmailField, self).to_primative(value)
|
||||
if ret is None:
|
||||
return None
|
||||
return ret.strip()
|
||||
|
||||
|
||||
class RegexField(CharField):
|
||||
def __init__(self, regex, **kwargs):
|
||||
kwargs['validators'] = (
|
||||
[validators.RegexValidator(regex)] +
|
||||
kwargs.get('validators', [])
|
||||
)
|
||||
super(RegexField, self).__init__(**kwargs)
|
||||
|
||||
|
||||
class SlugField(CharField):
|
||||
default_error_messages = {
|
||||
'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
|
||||
}
|
||||
default_validators = [validators.validate_slug]
|
||||
|
||||
|
||||
class URLField(CharField):
|
||||
default_error_messages = {
|
||||
'invalid': _("Enter a valid URL.")
|
||||
}
|
||||
default_validators = [validators.URLValidator()]
|
||||
|
||||
|
||||
# Number types...
|
||||
|
||||
class IntegerField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _('A valid integer is required.')
|
||||
}
|
||||
coerce_to_type = str
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
choices = kwargs.pop('choices')
|
||||
max_value = kwargs.pop('max_value', None)
|
||||
min_value = kwargs.pop('min_value', None)
|
||||
super(IntegerField, self).__init__(**kwargs)
|
||||
if max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(max_value))
|
||||
if min_value is not None:
|
||||
self.validators.append(validators.MinValueValidator(min_value))
|
||||
print self.__class__.__name__, self.validators
|
||||
|
||||
assert choices, '`choices` argument is required and may not be empty'
|
||||
def to_native(self, data):
|
||||
try:
|
||||
data = int(str(data))
|
||||
except (ValueError, TypeError):
|
||||
self.fail('invalid')
|
||||
return data
|
||||
|
||||
def to_primative(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
|
||||
|
||||
class FloatField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _("'%s' value must be a float."),
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
max_value = kwargs.pop('max_value', None)
|
||||
min_value = kwargs.pop('min_value', None)
|
||||
super(FloatField, self).__init__(**kwargs)
|
||||
if max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(max_value))
|
||||
if min_value is not None:
|
||||
self.validators.append(validators.MinValueValidator(min_value))
|
||||
|
||||
def to_primative(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
self.fail('invalid', value=value)
|
||||
|
||||
def to_native(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
return float(value)
|
||||
|
||||
|
||||
class DecimalField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _('Enter a number.'),
|
||||
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
|
||||
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
|
||||
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
|
||||
'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
|
||||
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
|
||||
}
|
||||
|
||||
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs):
|
||||
self.max_value, self.min_value = max_value, min_value
|
||||
self.max_digits, self.max_decimal_places = max_digits, decimal_places
|
||||
super(DecimalField, self).__init__(**kwargs)
|
||||
if max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(max_value))
|
||||
if min_value is not None:
|
||||
self.validators.append(validators.MinValueValidator(min_value))
|
||||
|
||||
def from_native(self, value):
|
||||
"""
|
||||
Validates that the input is a decimal number. Returns a Decimal
|
||||
instance. Returns None for empty values. Ensures that there are no more
|
||||
than max_digits in the number, and no more than decimal_places digits
|
||||
after the decimal point.
|
||||
"""
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return None
|
||||
|
||||
value = smart_text(value).strip()
|
||||
try:
|
||||
value = decimal.Decimal(value)
|
||||
except decimal.DecimalException:
|
||||
self.fail('invalid')
|
||||
|
||||
# Check for NaN. It is the only value that isn't equal to itself,
|
||||
# so we can use this to identify NaN values.
|
||||
if value != value:
|
||||
self.fail('invalid')
|
||||
|
||||
# Check for infinity and negative infinity.
|
||||
if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')):
|
||||
self.fail('invalid')
|
||||
|
||||
sign, digittuple, exponent = value.as_tuple()
|
||||
decimals = abs(exponent)
|
||||
# digittuple doesn't include any leading zeros.
|
||||
digits = len(digittuple)
|
||||
if decimals > digits:
|
||||
# We have leading zeros up to or past the decimal point. Count
|
||||
# everything past the decimal point as a digit. We do not count
|
||||
# 0 before the decimal point as a digit since that would mean
|
||||
# we would not allow max_digits = decimal_places.
|
||||
digits = decimals
|
||||
whole_digits = digits - decimals
|
||||
|
||||
if self.max_digits is not None and digits > self.max_digits:
|
||||
self.fail('max_digits', max_digits=self.max_digits)
|
||||
if self.decimal_places is not None and decimals > self.decimal_places:
|
||||
self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places)
|
||||
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
|
||||
self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# Date & time fields...
|
||||
|
||||
class DateField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
|
||||
}
|
||||
input_formats = api_settings.DATE_INPUT_FORMATS
|
||||
format = api_settings.DATE_FORMAT
|
||||
|
||||
def __init__(self, input_formats=None, format=None, *args, **kwargs):
|
||||
self.input_formats = input_formats if input_formats is not None else self.input_formats
|
||||
self.format = format if format is not None else self.format
|
||||
super(DateField, self).__init__(*args, **kwargs)
|
||||
|
||||
def from_native(self, value):
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return None
|
||||
|
||||
if isinstance(value, datetime.datetime):
|
||||
if timezone and settings.USE_TZ and timezone.is_aware(value):
|
||||
# Convert aware datetimes to the default time zone
|
||||
# before casting them to dates (#17742).
|
||||
default_timezone = timezone.get_default_timezone()
|
||||
value = timezone.make_naive(value, default_timezone)
|
||||
return value.date()
|
||||
if isinstance(value, datetime.date):
|
||||
return value
|
||||
|
||||
for format in self.input_formats:
|
||||
if format.lower() == ISO_8601:
|
||||
try:
|
||||
parsed = parse_date(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
else:
|
||||
try:
|
||||
parsed = datetime.datetime.strptime(value, format)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
return parsed.date()
|
||||
|
||||
humanized_format = humanize_datetime.date_formats(self.input_formats)
|
||||
msg = self.error_messages['invalid'] % humanized_format
|
||||
raise ValidationError(msg)
|
||||
|
||||
def to_primative(self, value):
|
||||
if value is None or self.format is None:
|
||||
return value
|
||||
|
||||
if isinstance(value, datetime.datetime):
|
||||
value = value.date()
|
||||
|
||||
if self.format.lower() == ISO_8601:
|
||||
return value.isoformat()
|
||||
return value.strftime(self.format)
|
||||
|
||||
|
||||
class DateTimeField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
|
||||
}
|
||||
input_formats = api_settings.DATETIME_INPUT_FORMATS
|
||||
format = api_settings.DATETIME_FORMAT
|
||||
|
||||
def __init__(self, input_formats=None, format=None, *args, **kwargs):
|
||||
self.input_formats = input_formats if input_formats is not None else self.input_formats
|
||||
self.format = format if format is not None else self.format
|
||||
super(DateTimeField, self).__init__(*args, **kwargs)
|
||||
|
||||
def from_native(self, value):
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return None
|
||||
|
||||
if isinstance(value, datetime.datetime):
|
||||
return value
|
||||
if isinstance(value, datetime.date):
|
||||
value = datetime.datetime(value.year, value.month, value.day)
|
||||
if settings.USE_TZ:
|
||||
# For backwards compatibility, interpret naive datetimes in
|
||||
# local time. This won't work during DST change, but we can't
|
||||
# do much about it, so we let the exceptions percolate up the
|
||||
# call stack.
|
||||
warnings.warn("DateTimeField received a naive datetime (%s)"
|
||||
" while time zone support is active." % value,
|
||||
RuntimeWarning)
|
||||
default_timezone = timezone.get_default_timezone()
|
||||
value = timezone.make_aware(value, default_timezone)
|
||||
return value
|
||||
|
||||
for format in self.input_formats:
|
||||
if format.lower() == ISO_8601:
|
||||
try:
|
||||
parsed = parse_datetime(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
else:
|
||||
try:
|
||||
parsed = datetime.datetime.strptime(value, format)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
return parsed
|
||||
|
||||
humanized_format = humanize_datetime.datetime_formats(self.input_formats)
|
||||
msg = self.error_messages['invalid'] % humanized_format
|
||||
raise ValidationError(msg)
|
||||
|
||||
def to_primative(self, value):
|
||||
if value is None or self.format is None:
|
||||
return value
|
||||
|
||||
if self.format.lower() == ISO_8601:
|
||||
ret = value.isoformat()
|
||||
if ret.endswith('+00:00'):
|
||||
ret = ret[:-6] + 'Z'
|
||||
return ret
|
||||
return value.strftime(self.format)
|
||||
|
||||
|
||||
class TimeField(Field):
|
||||
default_error_messages = {
|
||||
'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
|
||||
}
|
||||
input_formats = api_settings.TIME_INPUT_FORMATS
|
||||
format = api_settings.TIME_FORMAT
|
||||
|
||||
def __init__(self, input_formats=None, format=None, *args, **kwargs):
|
||||
self.input_formats = input_formats if input_formats is not None else self.input_formats
|
||||
self.format = format if format is not None else self.format
|
||||
super(TimeField, self).__init__(*args, **kwargs)
|
||||
|
||||
def from_native(self, value):
|
||||
if value in validators.EMPTY_VALUES:
|
||||
return None
|
||||
|
||||
if isinstance(value, datetime.time):
|
||||
return value
|
||||
|
||||
for format in self.input_formats:
|
||||
if format.lower() == ISO_8601:
|
||||
try:
|
||||
parsed = parse_time(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
if parsed is not None:
|
||||
return parsed
|
||||
else:
|
||||
try:
|
||||
parsed = datetime.datetime.strptime(value, format)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
else:
|
||||
return parsed.time()
|
||||
|
||||
humanized_format = humanize_datetime.time_formats(self.input_formats)
|
||||
msg = self.error_messages['invalid'] % humanized_format
|
||||
raise ValidationError(msg)
|
||||
|
||||
def to_primative(self, value):
|
||||
if value is None or self.format is None:
|
||||
return value
|
||||
|
||||
if isinstance(value, datetime.datetime):
|
||||
value = value.time()
|
||||
|
||||
if self.format.lower() == ISO_8601:
|
||||
return value.isoformat()
|
||||
return value.strftime(self.format)
|
||||
|
||||
|
||||
# Choice types...
|
||||
|
||||
class ChoiceField(Field):
|
||||
default_error_messages = {
|
||||
'invalid_choice': _('`{input}` is not a valid choice.')
|
||||
}
|
||||
|
||||
def __init__(self, choices, **kwargs):
|
||||
# Allow either single or paired choices style:
|
||||
# choices = [1, 2, 3]
|
||||
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
|
||||
|
@ -321,12 +694,14 @@ class ChoiceField(Field):
|
|||
except KeyError:
|
||||
self.fail('invalid_choice', input=data)
|
||||
|
||||
def to_primative(self, value):
|
||||
return value
|
||||
|
||||
|
||||
class MultipleChoiceField(ChoiceField):
|
||||
MESSAGES = {
|
||||
'required': 'This field is required.',
|
||||
'invalid_choice': '`{input}` is not a valid choice.',
|
||||
'not_a_list': 'Expected a list of items but got type `{input_type}`'
|
||||
default_error_messages = {
|
||||
'invalid_choice': _('`{input}` is not a valid choice.'),
|
||||
'not_a_list': _('Expected a list of items but got type `{input_type}`')
|
||||
}
|
||||
|
||||
def to_native(self, data):
|
||||
|
@ -337,72 +712,42 @@ class MultipleChoiceField(ChoiceField):
|
|||
for item in data
|
||||
])
|
||||
|
||||
|
||||
class IntegerField(Field):
|
||||
MESSAGES = {
|
||||
'required': 'This field is required.',
|
||||
'invalid_integer': 'A valid integer is required.'
|
||||
}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
max_value = kwargs.pop('max_value', None)
|
||||
min_value = kwargs.pop('min_value', None)
|
||||
super(IntegerField, self).__init__(**kwargs)
|
||||
if max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(max_value))
|
||||
if min_value is not None:
|
||||
self.validators.append(validators.MinValueValidator(min_value))
|
||||
|
||||
def to_native(self, data):
|
||||
try:
|
||||
data = int(str(data))
|
||||
except (ValueError, TypeError):
|
||||
self.fail('invalid_integer')
|
||||
return data
|
||||
|
||||
def to_primative(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
return value
|
||||
|
||||
|
||||
class EmailField(CharField):
|
||||
pass # TODO
|
||||
|
||||
|
||||
class URLField(CharField):
|
||||
pass # TODO
|
||||
|
||||
|
||||
class RegexField(CharField):
|
||||
def __init__(self, **kwargs):
|
||||
self.regex = kwargs.pop('regex')
|
||||
super(CharField, self).__init__(**kwargs)
|
||||
|
||||
|
||||
class DateField(CharField):
|
||||
def __init__(self, **kwargs):
|
||||
self.input_formats = kwargs.pop('input_formats', None)
|
||||
super(DateField, self).__init__(**kwargs)
|
||||
|
||||
|
||||
class TimeField(CharField):
|
||||
def __init__(self, **kwargs):
|
||||
self.input_formats = kwargs.pop('input_formats', None)
|
||||
super(TimeField, self).__init__(**kwargs)
|
||||
|
||||
|
||||
class DateTimeField(CharField):
|
||||
def __init__(self, **kwargs):
|
||||
self.input_formats = kwargs.pop('input_formats', None)
|
||||
super(DateTimeField, self).__init__(**kwargs)
|
||||
|
||||
# File types...
|
||||
|
||||
class FileField(Field):
|
||||
pass # TODO
|
||||
|
||||
|
||||
class ImageField(Field):
|
||||
pass # TODO
|
||||
|
||||
|
||||
# Advanced field types...
|
||||
|
||||
class ReadOnlyField(Field):
|
||||
"""
|
||||
A read-only field that simply returns the field value.
|
||||
|
||||
If the field is a method with no parameters, the method will be called
|
||||
and it's return value used as the representation.
|
||||
|
||||
For example, the following would call `get_expiry_date()` on the object:
|
||||
|
||||
class ExampleSerializer(self):
|
||||
expiry_date = ReadOnlyField(source='get_expiry_date')
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs['read_only'] = True
|
||||
super(ReadOnlyField, self).__init__(**kwargs)
|
||||
|
||||
def to_native(self, data):
|
||||
raise NotImplemented('.to_native() not supported.')
|
||||
|
||||
def to_primative(self, value):
|
||||
if is_simple_callable(value):
|
||||
return value()
|
||||
|
@ -410,11 +755,28 @@ class ReadOnlyField(Field):
|
|||
|
||||
|
||||
class MethodField(Field):
|
||||
"""
|
||||
A read-only field that get its representation from calling a method on the
|
||||
parent serializer class. The method called will be of the form
|
||||
"get_{field_name}", and should take a single argument, which is the
|
||||
object being serialized.
|
||||
|
||||
For example:
|
||||
|
||||
class ExampleSerializer(self):
|
||||
extra_info = MethodField()
|
||||
|
||||
def get_extra_info(self, obj):
|
||||
return ... # Calculate some data to return.
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
kwargs['source'] = '*'
|
||||
kwargs['read_only'] = True
|
||||
super(MethodField, self).__init__(**kwargs)
|
||||
|
||||
def to_native(self, data):
|
||||
raise NotImplemented('.to_native() not supported.')
|
||||
|
||||
def to_primative(self, value):
|
||||
attr = 'get_{field_name}'.format(field_name=self.field_name)
|
||||
method = getattr(self.parent, attr)
|
||||
|
@ -424,35 +786,14 @@ class MethodField(Field):
|
|||
class ModelField(Field):
|
||||
"""
|
||||
A generic field that can be used against an arbitrary model field.
|
||||
|
||||
This is used by `ModelSerializer` when dealing with custom model fields,
|
||||
that do not have a serializer field to be mapped to.
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
try:
|
||||
self.model_field = kwargs.pop('model_field')
|
||||
except KeyError:
|
||||
raise ValueError("ModelField requires 'model_field' kwarg")
|
||||
|
||||
self.min_length = kwargs.pop('min_length',
|
||||
getattr(self.model_field, 'min_length', None))
|
||||
self.max_length = kwargs.pop('max_length',
|
||||
getattr(self.model_field, 'max_length', None))
|
||||
self.min_value = kwargs.pop('min_value',
|
||||
getattr(self.model_field, 'min_value', None))
|
||||
self.max_value = kwargs.pop('max_value',
|
||||
getattr(self.model_field, 'max_value', None))
|
||||
|
||||
super(ModelField, self).__init__(*args, **kwargs)
|
||||
|
||||
if self.min_length is not None:
|
||||
self.validators.append(validators.MinLengthValidator(self.min_length))
|
||||
if self.max_length is not None:
|
||||
self.validators.append(validators.MaxLengthValidator(self.max_length))
|
||||
if self.min_value is not None:
|
||||
self.validators.append(validators.MinValueValidator(self.min_value))
|
||||
if self.max_value is not None:
|
||||
self.validators.append(validators.MaxValueValidator(self.max_value))
|
||||
|
||||
def get_attribute(self, instance):
|
||||
return get_attribute(instance, self.source_attrs[:-1])
|
||||
def __init__(self, model_field, **kwargs):
|
||||
self.model_field = model_field
|
||||
kwargs['source'] = '*'
|
||||
super(ModelField, self).__init__(**kwargs)
|
||||
|
||||
def to_native(self, data):
|
||||
rel = getattr(self.model_field, 'rel', None)
|
||||
|
|
|
@ -10,15 +10,15 @@ python primitives.
|
|||
2. The process of marshalling between python primitives and request and
|
||||
response content is handled by parsers and renderers.
|
||||
"""
|
||||
from django.core import validators
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.utils import six
|
||||
from collections import namedtuple, OrderedDict
|
||||
from rest_framework.fields import empty, set_value, Field, SkipField
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework.utils import html
|
||||
from rest_framework.utils import html, modelinfo, representation
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
# Note: We do the following so that users of the framework can use this style:
|
||||
#
|
||||
|
@ -146,12 +146,10 @@ class SerializerMetaclass(type):
|
|||
class Serializer(BaseSerializer):
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
many = kwargs.pop('many', False)
|
||||
if many:
|
||||
class DynamicListSerializer(ListSerializer):
|
||||
child = cls()
|
||||
return DynamicListSerializer(*args, **kwargs)
|
||||
return super(Serializer, cls).__new__(cls)
|
||||
if kwargs.pop('many', False):
|
||||
kwargs['child'] = cls()
|
||||
return ListSerializer(*args, **kwargs)
|
||||
return super(Serializer, cls).__new__(cls, *args, **kwargs)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.context = kwargs.pop('context', {})
|
||||
|
@ -248,6 +246,9 @@ class Serializer(BaseSerializer):
|
|||
error = errors.get(field.field_name)
|
||||
yield FieldResult(field, value, error)
|
||||
|
||||
def __repr__(self):
|
||||
return representation.serializer_repr(self, indent=1)
|
||||
|
||||
|
||||
class ListSerializer(BaseSerializer):
|
||||
child = None
|
||||
|
@ -299,26 +300,8 @@ class ListSerializer(BaseSerializer):
|
|||
self.instance = self.create(self.validated_data)
|
||||
return self.instance
|
||||
|
||||
|
||||
def _resolve_model(obj):
|
||||
"""
|
||||
Resolve supplied `obj` to a Django model class.
|
||||
|
||||
`obj` must be a Django model class itself, or a string
|
||||
representation of one. Useful in situtations like GH #1225 where
|
||||
Django may not have resolved a string-based reference to a model in
|
||||
another model's foreign key definition.
|
||||
|
||||
String representations should have the format:
|
||||
'appname.ModelName'
|
||||
"""
|
||||
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
|
||||
app_name, model_name = obj.split('.')
|
||||
return models.get_model(app_name, model_name)
|
||||
elif inspect.isclass(obj) and issubclass(obj, models.Model):
|
||||
return obj
|
||||
else:
|
||||
raise ValueError("{0} is not a Django model".format(obj))
|
||||
def __repr__(self):
|
||||
return representation.list_repr(self, indent=1)
|
||||
|
||||
|
||||
class ModelSerializerOptions(object):
|
||||
|
@ -334,24 +317,25 @@ class ModelSerializerOptions(object):
|
|||
class ModelSerializer(Serializer):
|
||||
field_mapping = {
|
||||
models.AutoField: IntegerField,
|
||||
# models.FloatField: FloatField,
|
||||
models.IntegerField: IntegerField,
|
||||
models.PositiveIntegerField: IntegerField,
|
||||
models.SmallIntegerField: IntegerField,
|
||||
models.PositiveSmallIntegerField: IntegerField,
|
||||
models.DateTimeField: DateTimeField,
|
||||
models.DateField: DateField,
|
||||
models.TimeField: TimeField,
|
||||
# models.DecimalField: DecimalField,
|
||||
models.EmailField: EmailField,
|
||||
models.CharField: CharField,
|
||||
models.URLField: URLField,
|
||||
# models.SlugField: SlugField,
|
||||
models.TextField: CharField,
|
||||
models.CommaSeparatedIntegerField: CharField,
|
||||
models.BigIntegerField: IntegerField,
|
||||
models.BooleanField: BooleanField,
|
||||
models.NullBooleanField: BooleanField,
|
||||
models.CharField: CharField,
|
||||
models.CommaSeparatedIntegerField: CharField,
|
||||
models.DateField: DateField,
|
||||
models.DateTimeField: DateTimeField,
|
||||
models.DecimalField: DecimalField,
|
||||
models.EmailField: EmailField,
|
||||
models.FileField: FileField,
|
||||
models.FloatField: FloatField,
|
||||
models.IntegerField: IntegerField,
|
||||
models.NullBooleanField: BooleanField,
|
||||
models.PositiveIntegerField: IntegerField,
|
||||
models.PositiveSmallIntegerField: IntegerField,
|
||||
models.SlugField: SlugField,
|
||||
models.SmallIntegerField: IntegerField,
|
||||
models.TextField: CharField,
|
||||
models.TimeField: TimeField,
|
||||
models.URLField: URLField,
|
||||
# models.ImageField: ImageField,
|
||||
}
|
||||
|
||||
|
@ -392,85 +376,31 @@ class ModelSerializer(Serializer):
|
|||
"""
|
||||
Return all the fields that should be serialized for the model.
|
||||
"""
|
||||
cls = self.opts.model
|
||||
opts = cls._meta.concrete_model._meta
|
||||
info = modelinfo.get_field_info(self.opts.model)
|
||||
ret = OrderedDict()
|
||||
nested = bool(self.opts.depth)
|
||||
|
||||
# Deal with adding the primary key field
|
||||
pk_field = opts.pk
|
||||
while pk_field.rel and pk_field.rel.parent_link:
|
||||
# If model is a child via multitable inheritance, use parent's pk
|
||||
pk_field = pk_field.rel.to._meta.pk
|
||||
|
||||
serializer_pk_field = self.get_pk_field(pk_field)
|
||||
serializer_pk_field = self.get_pk_field(info.pk)
|
||||
if serializer_pk_field:
|
||||
ret[pk_field.name] = serializer_pk_field
|
||||
ret[info.pk.name] = serializer_pk_field
|
||||
|
||||
# Deal with forward relationships
|
||||
forward_rels = [field for field in opts.fields if field.serialize]
|
||||
forward_rels += [field for field in opts.many_to_many if field.serialize]
|
||||
# Regular fields
|
||||
for field_name, field in info.fields.items():
|
||||
ret[field_name] = self.get_field(field)
|
||||
|
||||
for model_field in forward_rels:
|
||||
has_through_model = False
|
||||
|
||||
if model_field.rel:
|
||||
to_many = isinstance(model_field,
|
||||
models.fields.related.ManyToManyField)
|
||||
related_model = _resolve_model(model_field.rel.to)
|
||||
|
||||
if to_many and not model_field.rel.through._meta.auto_created:
|
||||
has_through_model = True
|
||||
|
||||
if model_field.rel and nested:
|
||||
field = self.get_nested_field(model_field, related_model, to_many)
|
||||
elif model_field.rel:
|
||||
field = self.get_related_field(model_field, related_model, to_many)
|
||||
# Forward relations
|
||||
for field_name, relation_info in info.forward_relations.items():
|
||||
if self.opts.depth:
|
||||
ret[field_name] = self.get_nested_field(*relation_info)
|
||||
else:
|
||||
field = self.get_field(model_field)
|
||||
ret[field_name] = self.get_related_field(*relation_info)
|
||||
|
||||
if field:
|
||||
if has_through_model:
|
||||
field.read_only = True
|
||||
|
||||
ret[model_field.name] = field
|
||||
|
||||
# Deal with reverse relationships
|
||||
if not self.opts.fields:
|
||||
reverse_rels = []
|
||||
else:
|
||||
# Reverse relationships are only included if they are explicitly
|
||||
# present in the `fields` option on the serializer
|
||||
reverse_rels = opts.get_all_related_objects()
|
||||
reverse_rels += opts.get_all_related_many_to_many_objects()
|
||||
|
||||
for relation in reverse_rels:
|
||||
accessor_name = relation.get_accessor_name()
|
||||
if not self.opts.fields or accessor_name not in self.opts.fields:
|
||||
continue
|
||||
related_model = relation.model
|
||||
to_many = relation.field.rel.multiple
|
||||
has_through_model = False
|
||||
is_m2m = isinstance(relation.field,
|
||||
models.fields.related.ManyToManyField)
|
||||
|
||||
if (
|
||||
is_m2m and
|
||||
hasattr(relation.field.rel, 'through') and
|
||||
not relation.field.rel.through._meta.auto_created
|
||||
):
|
||||
has_through_model = True
|
||||
|
||||
if nested:
|
||||
field = self.get_nested_field(None, related_model, to_many)
|
||||
else:
|
||||
field = self.get_related_field(None, related_model, to_many)
|
||||
|
||||
if field:
|
||||
if has_through_model:
|
||||
field.read_only = True
|
||||
|
||||
ret[accessor_name] = field
|
||||
# Reverse relations
|
||||
for accessor_name, relation_info in info.reverse_relations.items():
|
||||
if accessor_name in self.opts.fields:
|
||||
if self.opts.depth:
|
||||
ret[field_name] = self.get_nested_field(*relation_info)
|
||||
else:
|
||||
ret[field_name] = self.get_related_field(*relation_info)
|
||||
|
||||
return ret
|
||||
|
||||
|
@ -480,7 +410,7 @@ class ModelSerializer(Serializer):
|
|||
"""
|
||||
return self.get_field(model_field)
|
||||
|
||||
def get_nested_field(self, model_field, related_model, to_many):
|
||||
def get_nested_field(self, model_field, related_model, to_many, has_through_model):
|
||||
"""
|
||||
Creates a default instance of a nested relational field.
|
||||
|
||||
|
@ -491,59 +421,148 @@ class ModelSerializer(Serializer):
|
|||
model = related_model
|
||||
depth = self.opts.depth - 1
|
||||
|
||||
return NestedModelSerializer(many=to_many)
|
||||
kwargs = {'read_only': True}
|
||||
if to_many:
|
||||
kwargs['many'] = True
|
||||
return NestedModelSerializer(**kwargs)
|
||||
|
||||
def get_related_field(self, model_field, related_model, to_many):
|
||||
def get_related_field(self, model_field, related_model, to_many, has_through_model):
|
||||
"""
|
||||
Creates a default instance of a flat relational field.
|
||||
|
||||
Note that model_field will be `None` for reverse relationships.
|
||||
"""
|
||||
# TODO: filter queryset using:
|
||||
# .using(db).complex_filter(self.rel.limit_choices_to)
|
||||
kwargs = {
|
||||
'queryset': related_model._default_manager,
|
||||
}
|
||||
|
||||
kwargs = {}
|
||||
# 'queryset': related_model._default_manager,
|
||||
# 'many': to_many
|
||||
# }
|
||||
if to_many:
|
||||
kwargs['many'] = True
|
||||
|
||||
if has_through_model:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
|
||||
if model_field:
|
||||
kwargs['required'] = not(model_field.null or model_field.blank)
|
||||
if model_field.null or model_field.blank:
|
||||
kwargs['required'] = False
|
||||
# if model_field.help_text is not None:
|
||||
# kwargs['help_text'] = model_field.help_text
|
||||
if model_field.verbose_name is not None:
|
||||
kwargs['label'] = model_field.verbose_name
|
||||
if not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
if model_field.verbose_name is not None:
|
||||
kwargs['label'] = model_field.verbose_name
|
||||
kwargs.pop('queryset', None)
|
||||
|
||||
return IntegerField(**kwargs)
|
||||
# TODO: return PrimaryKeyRelatedField(**kwargs)
|
||||
return PrimaryKeyRelatedField(**kwargs)
|
||||
|
||||
def get_field(self, model_field):
|
||||
"""
|
||||
Creates a default instance of a basic non-relational field.
|
||||
"""
|
||||
kwargs = {}
|
||||
validator_kwarg = model_field.validators
|
||||
|
||||
if model_field.null or model_field.blank:
|
||||
kwargs['required'] = False
|
||||
|
||||
if isinstance(model_field, models.AutoField) or not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
|
||||
if model_field.has_default():
|
||||
kwargs['default'] = model_field.get_default()
|
||||
|
||||
if issubclass(model_field.__class__, models.TextField):
|
||||
kwargs['widget'] = widgets.Textarea
|
||||
|
||||
if model_field.verbose_name is not None:
|
||||
kwargs['label'] = model_field.verbose_name
|
||||
|
||||
if model_field.validators is not None:
|
||||
kwargs['validators'] = model_field.validators
|
||||
if isinstance(model_field, models.AutoField) or not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
# Read only implies that the field is not required.
|
||||
# We have a cleaner repr on the instance if we don't set it.
|
||||
kwargs.pop('required', None)
|
||||
|
||||
if model_field.has_default():
|
||||
kwargs['default'] = model_field.get_default()
|
||||
# Having a default implies that the field is not required.
|
||||
# We have a cleaner repr on the instance if we don't set it.
|
||||
kwargs.pop('required', None)
|
||||
|
||||
# Ensure that max_length is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
max_length = getattr(model_field, 'max_length', None)
|
||||
if max_length is not None:
|
||||
kwargs['max_length'] = max_length
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MaxLengthValidator)
|
||||
]
|
||||
|
||||
# Ensure that min_length is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
min_length = getattr(model_field, 'min_length', None)
|
||||
if min_length is not None:
|
||||
kwargs['min_length'] = min_length
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MinLengthValidator)
|
||||
]
|
||||
|
||||
# Ensure that max_value is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
max_value = next((
|
||||
validator.limit_value for validator in validator_kwarg
|
||||
if isinstance(validator, validators.MaxValueValidator)
|
||||
), None)
|
||||
if max_value is not None:
|
||||
kwargs['max_value'] = max_value
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MaxValueValidator)
|
||||
]
|
||||
|
||||
# Ensure that max_value is passed explicitly as a keyword arg,
|
||||
# rather than as a validator.
|
||||
min_value = next((
|
||||
validator.limit_value for validator in validator_kwarg
|
||||
if isinstance(validator, validators.MinValueValidator)
|
||||
), None)
|
||||
if min_value is not None:
|
||||
kwargs['min_value'] = min_value
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.MinValueValidator)
|
||||
]
|
||||
|
||||
# URLField does not need to include the URLValidator argument,
|
||||
# as it is explicitly added in.
|
||||
if isinstance(model_field, models.URLField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if not isinstance(validator, validators.URLValidator)
|
||||
]
|
||||
|
||||
# EmailField does not need to include the validate_email argument,
|
||||
# as it is explicitly added in.
|
||||
if isinstance(model_field, models.EmailField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if validator is not validators.validate_email
|
||||
]
|
||||
|
||||
# SlugField do not need to include the 'validate_slug' argument,
|
||||
if isinstance(model_field, models.SlugField):
|
||||
validator_kwarg = [
|
||||
validator for validator in validator_kwarg
|
||||
if validator is not validators.validate_slug
|
||||
]
|
||||
|
||||
max_digits = getattr(model_field, 'max_digits', None)
|
||||
if max_digits is not None:
|
||||
kwargs['max_digits'] = max_digits
|
||||
|
||||
decimal_places = getattr(model_field, 'decimal_places', None)
|
||||
if decimal_places is not None:
|
||||
kwargs['decimal_places'] = decimal_places
|
||||
|
||||
if validator_kwarg:
|
||||
kwargs['validators'] = validator_kwarg
|
||||
|
||||
# if issubclass(model_field.__class__, models.TextField):
|
||||
# kwargs['widget'] = widgets.Textarea
|
||||
|
||||
# if model_field.help_text is not None:
|
||||
# kwargs['help_text'] = model_field.help_text
|
||||
|
@ -555,31 +574,10 @@ class ModelSerializer(Serializer):
|
|||
kwargs['empty'] = None
|
||||
return ChoiceField(**kwargs)
|
||||
|
||||
# put this below the ChoiceField because min_value isn't a valid initializer
|
||||
if issubclass(model_field.__class__, models.PositiveIntegerField) or \
|
||||
issubclass(model_field.__class__, models.PositiveSmallIntegerField):
|
||||
kwargs['min_value'] = 0
|
||||
|
||||
if model_field.null and \
|
||||
issubclass(model_field.__class__, (models.CharField, models.TextField)):
|
||||
kwargs['allow_none'] = True
|
||||
|
||||
# attribute_dict = {
|
||||
# models.CharField: ['max_length'],
|
||||
# models.CommaSeparatedIntegerField: ['max_length'],
|
||||
# models.DecimalField: ['max_digits', 'decimal_places'],
|
||||
# models.EmailField: ['max_length'],
|
||||
# models.FileField: ['max_length'],
|
||||
# models.ImageField: ['max_length'],
|
||||
# models.SlugField: ['max_length'],
|
||||
# models.URLField: ['max_length'],
|
||||
# }
|
||||
|
||||
# if model_field.__class__ in attribute_dict:
|
||||
# attributes = attribute_dict[model_field.__class__]
|
||||
# for attribute in attributes:
|
||||
# kwargs.update({attribute: getattr(model_field, attribute)})
|
||||
|
||||
try:
|
||||
return self.field_mapping[model_field.__class__](**kwargs)
|
||||
except KeyError:
|
||||
|
@ -594,28 +592,21 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
|
|||
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
|
||||
self.view_name = getattr(meta, 'view_name', None)
|
||||
self.lookup_field = getattr(meta, 'lookup_field', None)
|
||||
self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)
|
||||
|
||||
|
||||
class HyperlinkedModelSerializer(ModelSerializer):
|
||||
_options_class = HyperlinkedModelSerializerOptions
|
||||
_default_view_name = '%(model_name)s-detail'
|
||||
_hyperlink_field_class = HyperlinkedRelatedField
|
||||
_hyperlink_identify_field_class = HyperlinkedIdentityField
|
||||
|
||||
def get_default_fields(self):
|
||||
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
|
||||
|
||||
if self.opts.view_name is None:
|
||||
self.opts.view_name = self._get_default_view_name(self.opts.model)
|
||||
self.opts.view_name = self.get_default_view_name(self.opts.model)
|
||||
|
||||
if self.opts.url_field_name not in fields:
|
||||
url_field = self._hyperlink_identify_field_class(
|
||||
view_name=self.opts.view_name,
|
||||
lookup_field=self.opts.lookup_field
|
||||
)
|
||||
url_field_name = api_settings.URL_FIELD_NAME
|
||||
if url_field_name not in fields:
|
||||
ret = fields.__class__()
|
||||
ret[self.opts.url_field_name] = url_field
|
||||
ret[url_field_name] = self.get_url_field()
|
||||
ret.update(fields)
|
||||
fields = ret
|
||||
|
||||
|
@ -625,39 +616,48 @@ class HyperlinkedModelSerializer(ModelSerializer):
|
|||
if self.opts.fields and model_field.name in self.opts.fields:
|
||||
return self.get_field(model_field)
|
||||
|
||||
def get_related_field(self, model_field, related_model, to_many):
|
||||
def get_url_field(self):
|
||||
kwargs = {
|
||||
'view_name': self.get_default_view_name(self.opts.model)
|
||||
}
|
||||
if self.opts.lookup_field:
|
||||
kwargs['lookup_field'] = self.opts.lookup_field
|
||||
return HyperlinkedIdentityField(**kwargs)
|
||||
|
||||
def get_related_field(self, model_field, related_model, to_many, has_through_model):
|
||||
"""
|
||||
Creates a default instance of a flat relational field.
|
||||
"""
|
||||
# TODO: filter queryset using:
|
||||
# .using(db).complex_filter(self.rel.limit_choices_to)
|
||||
# kwargs = {
|
||||
# 'queryset': related_model._default_manager,
|
||||
# 'view_name': self._get_default_view_name(related_model),
|
||||
# 'many': to_many
|
||||
# }
|
||||
kwargs = {}
|
||||
kwargs = {
|
||||
'queryset': related_model._default_manager,
|
||||
'view_name': self.get_default_view_name(related_model),
|
||||
}
|
||||
|
||||
if to_many:
|
||||
kwargs['many'] = True
|
||||
|
||||
if has_through_model:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
|
||||
if model_field:
|
||||
kwargs['required'] = not(model_field.null or model_field.blank)
|
||||
if model_field.null or model_field.blank:
|
||||
kwargs['required'] = False
|
||||
# if model_field.help_text is not None:
|
||||
# kwargs['help_text'] = model_field.help_text
|
||||
if model_field.verbose_name is not None:
|
||||
kwargs['label'] = model_field.verbose_name
|
||||
if not model_field.editable:
|
||||
kwargs['read_only'] = True
|
||||
kwargs.pop('queryset', None)
|
||||
|
||||
return IntegerField(**kwargs)
|
||||
# if self.opts.lookup_field:
|
||||
# kwargs['lookup_field'] = self.opts.lookup_field
|
||||
return HyperlinkedRelatedField(**kwargs)
|
||||
|
||||
# return self._hyperlink_field_class(**kwargs)
|
||||
|
||||
def _get_default_view_name(self, model):
|
||||
def get_default_view_name(self, model):
|
||||
"""
|
||||
Return the view name to use if 'view_name' is not specified in 'Meta'
|
||||
Return the view name to use for related models.
|
||||
"""
|
||||
model_meta = model._meta
|
||||
format_kwargs = {
|
||||
'app_label': model_meta.app_label,
|
||||
'model_name': model_meta.object_name.lower()
|
||||
return '%(model_name)s-detail' % {
|
||||
'app_label': model._meta.app_label,
|
||||
'model_name': model._meta.object_name.lower()
|
||||
}
|
||||
return self._default_view_name % format_kwargs
|
||||
|
|
47
rest_framework/utils/humanize_datetime.py
Normal file
47
rest_framework/utils/humanize_datetime.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
"""
|
||||
Helper functions that convert strftime formats into more readable representations.
|
||||
"""
|
||||
from rest_framework import ISO_8601
|
||||
|
||||
|
||||
def datetime_formats(formats):
|
||||
format = ', '.join(formats).replace(
|
||||
ISO_8601,
|
||||
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
|
||||
)
|
||||
return humanize_strptime(format)
|
||||
|
||||
|
||||
def date_formats(formats):
|
||||
format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
|
||||
return humanize_strptime(format)
|
||||
|
||||
|
||||
def time_formats(formats):
|
||||
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
|
||||
return humanize_strptime(format)
|
||||
|
||||
|
||||
def humanize_strptime(format_string):
|
||||
# Note that we're missing some of the locale specific mappings that
|
||||
# don't really make sense.
|
||||
mapping = {
|
||||
"%Y": "YYYY",
|
||||
"%y": "YY",
|
||||
"%m": "MM",
|
||||
"%b": "[Jan-Dec]",
|
||||
"%B": "[January-December]",
|
||||
"%d": "DD",
|
||||
"%H": "hh",
|
||||
"%I": "hh", # Requires '%p' to differentiate from '%H'.
|
||||
"%M": "mm",
|
||||
"%S": "ss",
|
||||
"%f": "uuuuuu",
|
||||
"%a": "[Mon-Sun]",
|
||||
"%A": "[Monday-Sunday]",
|
||||
"%p": "[AM|PM]",
|
||||
"%z": "[+HHMM|-HHMM]"
|
||||
}
|
||||
for key, val in mapping.items():
|
||||
format_string = format_string.replace(key, val)
|
||||
return format_string
|
97
rest_framework/utils/modelinfo.py
Normal file
97
rest_framework/utils/modelinfo.py
Normal file
|
@ -0,0 +1,97 @@
|
|||
"""
|
||||
Helper functions for returning the field information that is associated
|
||||
with a model class.
|
||||
"""
|
||||
from collections import namedtuple, OrderedDict
|
||||
from django.db import models
|
||||
from django.utils import six
|
||||
import inspect
|
||||
|
||||
FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
|
||||
RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
|
||||
|
||||
|
||||
def _resolve_model(obj):
|
||||
"""
|
||||
Resolve supplied `obj` to a Django model class.
|
||||
|
||||
`obj` must be a Django model class itself, or a string
|
||||
representation of one. Useful in situtations like GH #1225 where
|
||||
Django may not have resolved a string-based reference to a model in
|
||||
another model's foreign key definition.
|
||||
|
||||
String representations should have the format:
|
||||
'appname.ModelName'
|
||||
"""
|
||||
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
|
||||
app_name, model_name = obj.split('.')
|
||||
return models.get_model(app_name, model_name)
|
||||
elif inspect.isclass(obj) and issubclass(obj, models.Model):
|
||||
return obj
|
||||
raise ValueError("{0} is not a Django model".format(obj))
|
||||
|
||||
|
||||
def get_field_info(model):
|
||||
"""
|
||||
Given a model class, returns a `FieldInfo` instance containing metadata
|
||||
about the various field types on the model.
|
||||
"""
|
||||
opts = model._meta.concrete_model._meta
|
||||
|
||||
# Deal with the primary key.
|
||||
pk = opts.pk
|
||||
while pk.rel and pk.rel.parent_link:
|
||||
# If model is a child via multitable inheritance, use parent's pk.
|
||||
pk = pk.rel.to._meta.pk
|
||||
|
||||
# Deal with regular fields.
|
||||
fields = OrderedDict()
|
||||
for field in [field for field in opts.fields if field.serialize and not field.rel]:
|
||||
fields[field.name] = field
|
||||
|
||||
# Deal with forward relationships.
|
||||
forward_relations = OrderedDict()
|
||||
for field in [field for field in opts.fields if field.serialize and field.rel]:
|
||||
forward_relations[field.name] = RelationInfo(
|
||||
field=field,
|
||||
related=_resolve_model(field.rel.to),
|
||||
to_many=False,
|
||||
has_through_model=False
|
||||
)
|
||||
|
||||
# Deal with forward many-to-many relationships.
|
||||
for field in [field for field in opts.many_to_many if field.serialize]:
|
||||
forward_relations[field.name] = RelationInfo(
|
||||
field=field,
|
||||
related=_resolve_model(field.rel.to),
|
||||
to_many=True,
|
||||
has_through_model=(
|
||||
not field.rel.through._meta.auto_created
|
||||
)
|
||||
)
|
||||
|
||||
# Deal with reverse relationships.
|
||||
reverse_relations = OrderedDict()
|
||||
for relation in opts.get_all_related_objects():
|
||||
accessor_name = relation.get_accessor_name()
|
||||
reverse_relations[accessor_name] = RelationInfo(
|
||||
field=None,
|
||||
related=relation.model,
|
||||
to_many=relation.field.rel.multiple,
|
||||
has_through_model=False
|
||||
)
|
||||
|
||||
# Deal with reverse many-to-many relationships.
|
||||
for relation in opts.get_all_related_many_to_many_objects():
|
||||
accessor_name = relation.get_accessor_name()
|
||||
reverse_relations[accessor_name] = RelationInfo(
|
||||
field=None,
|
||||
related=relation.model,
|
||||
to_many=True,
|
||||
has_through_model=(
|
||||
hasattr(relation.field.rel, 'through') and
|
||||
not relation.field.rel.through._meta.auto_created
|
||||
)
|
||||
)
|
||||
|
||||
return FieldInfo(pk, fields, forward_relations, reverse_relations)
|
72
rest_framework/utils/representation.py
Normal file
72
rest_framework/utils/representation.py
Normal file
|
@ -0,0 +1,72 @@
|
|||
"""
|
||||
Helper functions for creating user-friendly representations
|
||||
of serializer classes and serializer fields.
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
def smart_repr(value):
|
||||
value = repr(value)
|
||||
|
||||
# Representations like u'help text'
|
||||
# should simply be presented as 'help text'
|
||||
if value.startswith("u'") and value.endswith("'"):
|
||||
return value[1:]
|
||||
|
||||
# Representations like
|
||||
# <django.core.validators.RegexValidator object at 0x1047af050>
|
||||
# Should be presented as
|
||||
# <django.core.validators.RegexValidator object>
|
||||
value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def field_repr(field, force_many=False):
|
||||
kwargs = field._kwargs
|
||||
if force_many:
|
||||
kwargs = kwargs.copy()
|
||||
kwargs['many'] = True
|
||||
kwargs.pop('child', None)
|
||||
|
||||
arg_string = ', '.join([smart_repr(val) for val in field._args])
|
||||
kwarg_string = ', '.join([
|
||||
'%s=%s' % (key, smart_repr(val))
|
||||
for key, val in sorted(kwargs.items())
|
||||
])
|
||||
if arg_string and kwarg_string:
|
||||
arg_string += ', '
|
||||
|
||||
if force_many:
|
||||
class_name = force_many.__class__.__name__
|
||||
else:
|
||||
class_name = field.__class__.__name__
|
||||
|
||||
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
|
||||
|
||||
|
||||
def serializer_repr(serializer, indent, force_many=None):
|
||||
ret = field_repr(serializer, force_many) + ':'
|
||||
indent_str = ' ' * indent
|
||||
|
||||
if force_many:
|
||||
fields = force_many.fields
|
||||
else:
|
||||
fields = serializer.fields
|
||||
|
||||
for field_name, field in fields.items():
|
||||
ret += '\n' + indent_str + field_name + ' = '
|
||||
if hasattr(field, 'fields'):
|
||||
ret += serializer_repr(field, indent + 1)
|
||||
elif hasattr(field, 'child'):
|
||||
ret += list_repr(field, indent + 1)
|
||||
else:
|
||||
ret += field_repr(field)
|
||||
return ret
|
||||
|
||||
|
||||
def list_repr(serializer, indent):
|
||||
child = serializer.child
|
||||
if hasattr(child, 'fields'):
|
||||
return serializer_repr(serializer, indent, force_many=child)
|
||||
return field_repr(serializer)
|
160
tests/test_model_field_mappings.py
Normal file
160
tests/test_model_field_mappings.py
Normal file
|
@ -0,0 +1,160 @@
|
|||
"""
|
||||
The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially
|
||||
shortcuts for automatically creating serializers based on a given model class.
|
||||
|
||||
These tests deal with ensuring that we correctly map the model fields onto
|
||||
an appropriate set of serializer fields for each case.
|
||||
"""
|
||||
from django.db import models
|
||||
from django.test import TestCase
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
# Models for testing regular field mapping
|
||||
|
||||
class RegularFieldsModel(models.Model):
|
||||
auto_field = models.AutoField(primary_key=True)
|
||||
big_integer_field = models.BigIntegerField()
|
||||
boolean_field = models.BooleanField()
|
||||
char_field = models.CharField(max_length=100)
|
||||
comma_seperated_integer_field = models.CommaSeparatedIntegerField(max_length=100)
|
||||
date_field = models.DateField()
|
||||
datetime_field = models.DateTimeField()
|
||||
decimal_field = models.DecimalField(max_digits=3, decimal_places=1)
|
||||
email_field = models.EmailField(max_length=100)
|
||||
float_field = models.FloatField()
|
||||
integer_field = models.IntegerField()
|
||||
null_boolean_field = models.NullBooleanField()
|
||||
positive_integer_field = models.PositiveIntegerField()
|
||||
positive_small_integer_field = models.PositiveSmallIntegerField()
|
||||
slug_field = models.SlugField(max_length=100)
|
||||
small_integer_field = models.SmallIntegerField()
|
||||
text_field = models.TextField()
|
||||
time_field = models.TimeField()
|
||||
url_field = models.URLField(max_length=100)
|
||||
|
||||
|
||||
REGULAR_FIELDS_REPR = """
|
||||
TestSerializer():
|
||||
auto_field = IntegerField(label='auto field', read_only=True)
|
||||
big_integer_field = IntegerField(label='big integer field')
|
||||
boolean_field = BooleanField(default=False, label='boolean field')
|
||||
char_field = CharField(label='char field', max_length=100)
|
||||
comma_seperated_integer_field = CharField(label='comma seperated integer field', max_length=100, validators=[<django.core.validators.RegexValidator object>])
|
||||
date_field = DateField(label='date field')
|
||||
datetime_field = DateTimeField(label='datetime field')
|
||||
decimal_field = DecimalField(decimal_places=1, label='decimal field', max_digits=3)
|
||||
email_field = EmailField(label='email field', max_length=100)
|
||||
float_field = FloatField(label='float field')
|
||||
integer_field = IntegerField(label='integer field')
|
||||
null_boolean_field = BooleanField(label='null boolean field', required=False)
|
||||
positive_integer_field = IntegerField(label='positive integer field')
|
||||
positive_small_integer_field = IntegerField(label='positive small integer field')
|
||||
slug_field = SlugField(label='slug field', max_length=100)
|
||||
small_integer_field = IntegerField(label='small integer field')
|
||||
text_field = CharField(label='text field')
|
||||
time_field = TimeField(label='time field')
|
||||
url_field = URLField(label='url field', max_length=100)
|
||||
""".strip()
|
||||
|
||||
|
||||
# Model for testing relational field mapping
|
||||
|
||||
class ForeignKeyTarget(models.Model):
|
||||
char_field = models.CharField(max_length=100)
|
||||
|
||||
|
||||
class ManyToManyTarget(models.Model):
|
||||
char_field = models.CharField(max_length=100)
|
||||
|
||||
|
||||
class OneToOneTarget(models.Model):
|
||||
char_field = models.CharField(max_length=100)
|
||||
|
||||
|
||||
class RelationalModel(models.Model):
|
||||
foreign_key = models.ForeignKey(ForeignKeyTarget)
|
||||
many_to_many = models.ManyToManyField(ManyToManyTarget)
|
||||
one_to_one = models.OneToOneField(OneToOneTarget)
|
||||
|
||||
|
||||
RELATIONAL_FLAT_REPR = """
|
||||
TestSerializer():
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
foreign_key = PrimaryKeyRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>)
|
||||
one_to_one = PrimaryKeyRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>)
|
||||
many_to_many = PrimaryKeyRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>)
|
||||
""".strip()
|
||||
|
||||
|
||||
RELATIONAL_NESTED_REPR = """
|
||||
TestSerializer():
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
foreign_key = NestedModelSerializer(read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(label='name', max_length=100)
|
||||
one_to_one = NestedModelSerializer(read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(label='name', max_length=100)
|
||||
many_to_many = NestedModelSerializer(many=True, read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(label='name', max_length=100)
|
||||
""".strip()
|
||||
|
||||
|
||||
HYPERLINKED_FLAT_REPR = """
|
||||
TestSerializer():
|
||||
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
|
||||
foreign_key = HyperlinkedRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>, view_name='foreignkeytarget-detail')
|
||||
one_to_one = HyperlinkedRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>, view_name='onetoonetarget-detail')
|
||||
many_to_many = HyperlinkedRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>, view_name='manytomanytarget-detail')
|
||||
""".strip()
|
||||
|
||||
|
||||
HYPERLINKED_NESTED_REPR = """
|
||||
TestSerializer():
|
||||
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
|
||||
foreign_key = NestedModelSerializer(read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(label='name', max_length=100)
|
||||
one_to_one = NestedModelSerializer(read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(label='name', max_length=100)
|
||||
many_to_many = NestedModelSerializer(many=True, read_only=True):
|
||||
id = IntegerField(label='ID', read_only=True)
|
||||
name = CharField(label='name', max_length=100)
|
||||
""".strip()
|
||||
|
||||
|
||||
class TestSerializerMappings(TestCase):
|
||||
def test_regular_fields(self):
|
||||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = RegularFieldsModel
|
||||
self.assertEqual(repr(TestSerializer()), REGULAR_FIELDS_REPR)
|
||||
|
||||
def test_flat_relational_fields(self):
|
||||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = RelationalModel
|
||||
self.assertEqual(repr(TestSerializer()), RELATIONAL_FLAT_REPR)
|
||||
|
||||
def test_nested_relational_fields(self):
|
||||
class TestSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = RelationalModel
|
||||
depth = 1
|
||||
self.assertEqual(repr(TestSerializer()), RELATIONAL_NESTED_REPR)
|
||||
|
||||
def test_flat_hyperlinked_fields(self):
|
||||
class TestSerializer(serializers.HyperlinkedModelSerializer):
|
||||
class Meta:
|
||||
model = RelationalModel
|
||||
self.assertEqual(repr(TestSerializer()), HYPERLINKED_FLAT_REPR)
|
||||
|
||||
def test_nested_hyperlinked_fields(self):
|
||||
class TestSerializer(serializers.HyperlinkedModelSerializer):
|
||||
class Meta:
|
||||
model = RelationalModel
|
||||
depth = 1
|
||||
self.assertEqual(repr(TestSerializer()), HYPERLINKED_NESTED_REPR)
|
|
@ -1,6 +1,6 @@
|
|||
from django.test import TestCase
|
||||
from django.utils import six
|
||||
from rest_framework.serializers import _resolve_model
|
||||
from rest_framework.utils.modelinfo import _resolve_model
|
||||
from tests.models import BasicModel
|
||||
|
||||
|
|
@ -22,18 +22,18 @@
|
|||
# https://github.com/tomchristie/django-rest-framework/issues/446
|
||||
# """
|
||||
# field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
|
||||
# self.assertRaises(serializers.ValidationError, field.from_native, '')
|
||||
# self.assertRaises(serializers.ValidationError, field.from_native, [])
|
||||
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
|
||||
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
|
||||
|
||||
# def test_hyperlinked_related_field_with_empty_string(self):
|
||||
# field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
|
||||
# self.assertRaises(serializers.ValidationError, field.from_native, '')
|
||||
# self.assertRaises(serializers.ValidationError, field.from_native, [])
|
||||
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
|
||||
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
|
||||
|
||||
# def test_slug_related_field_with_empty_string(self):
|
||||
# field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
|
||||
# self.assertRaises(serializers.ValidationError, field.from_native, '')
|
||||
# self.assertRaises(serializers.ValidationError, field.from_native, [])
|
||||
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
|
||||
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
|
||||
|
||||
|
||||
# class TestManyRelatedMixin(TestCase):
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
# def test_empty_serializer(self):
|
||||
# class FooBarSerializer(serializers.Serializer):
|
||||
# foo = serializers.IntegerField()
|
||||
# bar = serializers.SerializerMethodField('get_bar')
|
||||
# bar = serializers.MethodField()
|
||||
|
||||
# def get_bar(self, obj):
|
||||
# return 'bar'
|
||||
|
|
Loading…
Reference in New Issue
Block a user