Fleshing out serializer fields

This commit is contained in:
Tom Christie 2014-09-09 17:46:28 +01:00
parent 21980b800d
commit b1c07670ca
9 changed files with 1053 additions and 336 deletions

View File

@ -1,8 +1,18 @@
from django.conf import settings
from django.core import validators from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type from django.utils.encoding import is_protected_type
from rest_framework.utils import html from django.utils.translation import ugettext_lazy as _
from rest_framework import ISO_8601
from rest_framework.compat import smart_text
from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime
import datetime
import decimal
import inspect import inspect
import warnings
class empty: class empty:
@ -71,22 +81,22 @@ class SkipField(Exception):
pass pass
NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.'
)
class Field(object): class Field(object):
_creation_counter = 0 _creation_counter = 0
MESSAGES = { default_error_messages = {
'required': 'This field is required.' 'required': _('This field is required.')
} }
_NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
_NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
_NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
_NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
_MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `MESSAGES` dictionary.'
)
default_validators = [] default_validators = []
def __init__(self, read_only=False, write_only=False, def __init__(self, read_only=False, write_only=False,
@ -100,10 +110,10 @@ class Field(object):
required = default is empty and not read_only required = default is empty and not read_only
# Some combinations of keyword arguments do not make sense. # Some combinations of keyword arguments do not make sense.
assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED assert not (read_only and required), NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
self.read_only = read_only self.read_only = read_only
self.write_only = write_only self.write_only = write_only
@ -113,7 +123,14 @@ class Field(object):
self.initial = initial self.initial = initial
self.label = label self.label = label
self.style = {} if style is None else style self.style = {} if style is None else style
self.validators = self.default_validators + validators self.validators = validators or self.default_validators[:]
# Collect default error message from self and parent classes
messages = {}
for cls in reversed(self.__class__.__mro__):
messages.update(getattr(cls, 'default_error_messages', {}))
messages.update(error_messages or {})
self.error_messages = messages
def bind(self, field_name, parent, root): def bind(self, field_name, parent, root):
""" """
@ -186,12 +203,14 @@ class Field(object):
self.fail('required') self.fail('required')
return self.get_default() return self.get_default()
self.run_validators(data) value = self.to_native(data)
return self.to_native(data) self.run_validators(value)
return value
def run_validators(self, value): def run_validators(self, value):
if value in validators.EMPTY_VALUES: if value in validators.EMPTY_VALUES:
return return
errors = [] errors = []
for validator in self.validators: for validator in self.validators:
try: try:
@ -218,33 +237,32 @@ class Field(object):
A helper method that simply raises a validation error. A helper method that simply raises a validation error.
""" """
try: try:
raise ValidationError(self.MESSAGES[key].format(**kwargs)) msg = self.error_messages[key]
except KeyError: except KeyError:
class_name = self.__class__.__name__ class_name = self.__class__.__name__
msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg) raise AssertionError(msg)
raise ValidationError(msg.format(**kwargs))
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
"""
When a field is instantiated, we store the arguments that were used,
so that we can present a helpful representation of the object.
"""
instance = super(Field, cls).__new__(cls) instance = super(Field, cls).__new__(cls)
instance._args = args instance._args = args
instance._kwargs = kwargs instance._kwargs = kwargs
return instance return instance
def __repr__(self): def __repr__(self):
arg_string = ', '.join([repr(val) for val in self._args]) return representation.field_repr(self)
kwarg_string = ', '.join([
'%s=%s' % (key, repr(val)) for key, val in self._kwargs.items()
])
if arg_string and kwarg_string:
arg_string += ', '
class_name = self.__class__.__name__
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
# Boolean types...
class BooleanField(Field): class BooleanField(Field):
MESSAGES = { default_error_messages = {
'required': 'This field is required.', 'invalid': _('`{input}` is not a valid boolean.')
'invalid_value': '`{input}` is not a valid boolean.'
} }
TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True} TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False} FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
@ -261,13 +279,23 @@ class BooleanField(Field):
return True return True
elif data in self.FALSE_VALUES: elif data in self.FALSE_VALUES:
return False return False
self.fail('invalid_value', input=data) self.fail('invalid', input=data)
def to_primative(self, value):
if value is None:
return None
if value in self.TRUE_VALUES:
return True
elif value in self.FALSE_VALUES:
return False
return bool(value)
# String types...
class CharField(Field): class CharField(Field):
MESSAGES = { default_error_messages = {
'required': 'This field is required.', 'blank': _('This field may not be blank.')
'blank': 'This field may not be blank.'
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
@ -281,19 +309,364 @@ class CharField(Field):
self.fail('blank') self.fail('blank')
return str(data) return str(data)
def to_primative(self, value):
if value is None:
return None
return str(value)
class ChoiceField(Field):
MESSAGES = { class EmailField(CharField):
'required': 'This field is required.', default_error_messages = {
'invalid_choice': '`{input}` is not a valid choice.' 'invalid': _('Enter a valid email address.')
}
default_validators = [validators.validate_email]
def to_native(self, data):
ret = super(EmailField, self).to_native(data)
if ret is None:
return None
return ret.strip()
def to_primative(self, value):
ret = super(EmailField, self).to_primative(value)
if ret is None:
return None
return ret.strip()
class RegexField(CharField):
def __init__(self, regex, **kwargs):
kwargs['validators'] = (
[validators.RegexValidator(regex)] +
kwargs.get('validators', [])
)
super(RegexField, self).__init__(**kwargs)
class SlugField(CharField):
default_error_messages = {
'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
}
default_validators = [validators.validate_slug]
class URLField(CharField):
default_error_messages = {
'invalid': _("Enter a valid URL.")
}
default_validators = [validators.URLValidator()]
# Number types...
class IntegerField(Field):
default_error_messages = {
'invalid': _('A valid integer is required.')
} }
coerce_to_type = str
def __init__(self, **kwargs): def __init__(self, **kwargs):
choices = kwargs.pop('choices') max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(IntegerField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
print self.__class__.__name__, self.validators
assert choices, '`choices` argument is required and may not be empty' def to_native(self, data):
try:
data = int(str(data))
except (ValueError, TypeError):
self.fail('invalid')
return data
def to_primative(self, value):
if value is None:
return None
return int(value)
class FloatField(Field):
default_error_messages = {
'invalid': _("'%s' value must be a float."),
}
def __init__(self, **kwargs):
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(FloatField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def to_primative(self, value):
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
self.fail('invalid', value=value)
def to_native(self, value):
if value is None:
return None
return float(value)
class DecimalField(Field):
default_error_messages = {
'invalid': _('Enter a number.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
}
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs):
self.max_value, self.min_value = max_value, min_value
self.max_digits, self.max_decimal_places = max_digits, decimal_places
super(DecimalField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def from_native(self, value):
"""
Validates that the input is a decimal number. Returns a Decimal
instance. Returns None for empty values. Ensures that there are no more
than max_digits in the number, and no more than decimal_places digits
after the decimal point.
"""
if value in validators.EMPTY_VALUES:
return None
value = smart_text(value).strip()
try:
value = decimal.Decimal(value)
except decimal.DecimalException:
self.fail('invalid')
# Check for NaN. It is the only value that isn't equal to itself,
# so we can use this to identify NaN values.
if value != value:
self.fail('invalid')
# Check for infinity and negative infinity.
if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')):
self.fail('invalid')
sign, digittuple, exponent = value.as_tuple()
decimals = abs(exponent)
# digittuple doesn't include any leading zeros.
digits = len(digittuple)
if decimals > digits:
# We have leading zeros up to or past the decimal point. Count
# everything past the decimal point as a digit. We do not count
# 0 before the decimal point as a digit since that would mean
# we would not allow max_digits = decimal_places.
digits = decimals
whole_digits = digits - decimals
if self.max_digits is not None and digits > self.max_digits:
self.fail('max_digits', max_digits=self.max_digits)
if self.decimal_places is not None and decimals > self.decimal_places:
self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
return value
# Date & time fields...
class DateField(Field):
default_error_messages = {
'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
}
input_formats = api_settings.DATE_INPUT_FORMATS
format = api_settings.DATE_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(DateField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.datetime):
if timezone and settings.USE_TZ and timezone.is_aware(value):
# Convert aware datetimes to the default time zone
# before casting them to dates (#17742).
default_timezone = timezone.get_default_timezone()
value = timezone.make_naive(value, default_timezone)
return value.date()
if isinstance(value, datetime.date):
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_date(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed.date()
humanized_format = humanize_datetime.date_formats(self.input_formats)
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
def to_primative(self, value):
if value is None or self.format is None:
return value
if isinstance(value, datetime.datetime):
value = value.date()
if self.format.lower() == ISO_8601:
return value.isoformat()
return value.strftime(self.format)
class DateTimeField(Field):
default_error_messages = {
'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
}
input_formats = api_settings.DATETIME_INPUT_FORMATS
format = api_settings.DATETIME_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(DateTimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
value = datetime.datetime(value.year, value.month, value.day)
if settings.USE_TZ:
# For backwards compatibility, interpret naive datetimes in
# local time. This won't work during DST change, but we can't
# do much about it, so we let the exceptions percolate up the
# call stack.
warnings.warn("DateTimeField received a naive datetime (%s)"
" while time zone support is active." % value,
RuntimeWarning)
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone)
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_datetime(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed
humanized_format = humanize_datetime.datetime_formats(self.input_formats)
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
def to_primative(self, value):
if value is None or self.format is None:
return value
if self.format.lower() == ISO_8601:
ret = value.isoformat()
if ret.endswith('+00:00'):
ret = ret[:-6] + 'Z'
return ret
return value.strftime(self.format)
class TimeField(Field):
default_error_messages = {
'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
}
input_formats = api_settings.TIME_INPUT_FORMATS
format = api_settings.TIME_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.time):
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_time(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed.time()
humanized_format = humanize_datetime.time_formats(self.input_formats)
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
def to_primative(self, value):
if value is None or self.format is None:
return value
if isinstance(value, datetime.datetime):
value = value.time()
if self.format.lower() == ISO_8601:
return value.isoformat()
return value.strftime(self.format)
# Choice types...
class ChoiceField(Field):
default_error_messages = {
'invalid_choice': _('`{input}` is not a valid choice.')
}
def __init__(self, choices, **kwargs):
# Allow either single or paired choices style: # Allow either single or paired choices style:
# choices = [1, 2, 3] # choices = [1, 2, 3]
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')] # choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
@ -321,12 +694,14 @@ class ChoiceField(Field):
except KeyError: except KeyError:
self.fail('invalid_choice', input=data) self.fail('invalid_choice', input=data)
def to_primative(self, value):
return value
class MultipleChoiceField(ChoiceField): class MultipleChoiceField(ChoiceField):
MESSAGES = { default_error_messages = {
'required': 'This field is required.', 'invalid_choice': _('`{input}` is not a valid choice.'),
'invalid_choice': '`{input}` is not a valid choice.', 'not_a_list': _('Expected a list of items but got type `{input_type}`')
'not_a_list': 'Expected a list of items but got type `{input_type}`'
} }
def to_native(self, data): def to_native(self, data):
@ -337,72 +712,42 @@ class MultipleChoiceField(ChoiceField):
for item in data for item in data
]) ])
class IntegerField(Field):
MESSAGES = {
'required': 'This field is required.',
'invalid_integer': 'A valid integer is required.'
}
def __init__(self, **kwargs):
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(IntegerField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def to_native(self, data):
try:
data = int(str(data))
except (ValueError, TypeError):
self.fail('invalid_integer')
return data
def to_primative(self, value): def to_primative(self, value):
if value is None: return value
return None
return int(value)
class EmailField(CharField): # File types...
pass # TODO
class URLField(CharField):
pass # TODO
class RegexField(CharField):
def __init__(self, **kwargs):
self.regex = kwargs.pop('regex')
super(CharField, self).__init__(**kwargs)
class DateField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(DateField, self).__init__(**kwargs)
class TimeField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(TimeField, self).__init__(**kwargs)
class DateTimeField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(DateTimeField, self).__init__(**kwargs)
class FileField(Field): class FileField(Field):
pass # TODO pass # TODO
class ImageField(Field):
pass # TODO
# Advanced field types...
class ReadOnlyField(Field): class ReadOnlyField(Field):
"""
A read-only field that simply returns the field value.
If the field is a method with no parameters, the method will be called
and it's return value used as the representation.
For example, the following would call `get_expiry_date()` on the object:
class ExampleSerializer(self):
expiry_date = ReadOnlyField(source='get_expiry_date')
"""
def __init__(self, **kwargs):
kwargs['read_only'] = True
super(ReadOnlyField, self).__init__(**kwargs)
def to_native(self, data):
raise NotImplemented('.to_native() not supported.')
def to_primative(self, value): def to_primative(self, value):
if is_simple_callable(value): if is_simple_callable(value):
return value() return value()
@ -410,11 +755,28 @@ class ReadOnlyField(Field):
class MethodField(Field): class MethodField(Field):
"""
A read-only field that get its representation from calling a method on the
parent serializer class. The method called will be of the form
"get_{field_name}", and should take a single argument, which is the
object being serialized.
For example:
class ExampleSerializer(self):
extra_info = MethodField()
def get_extra_info(self, obj):
return ... # Calculate some data to return.
"""
def __init__(self, **kwargs): def __init__(self, **kwargs):
kwargs['source'] = '*' kwargs['source'] = '*'
kwargs['read_only'] = True kwargs['read_only'] = True
super(MethodField, self).__init__(**kwargs) super(MethodField, self).__init__(**kwargs)
def to_native(self, data):
raise NotImplemented('.to_native() not supported.')
def to_primative(self, value): def to_primative(self, value):
attr = 'get_{field_name}'.format(field_name=self.field_name) attr = 'get_{field_name}'.format(field_name=self.field_name)
method = getattr(self.parent, attr) method = getattr(self.parent, attr)
@ -424,35 +786,14 @@ class MethodField(Field):
class ModelField(Field): class ModelField(Field):
""" """
A generic field that can be used against an arbitrary model field. A generic field that can be used against an arbitrary model field.
This is used by `ModelSerializer` when dealing with custom model fields,
that do not have a serializer field to be mapped to.
""" """
def __init__(self, *args, **kwargs): def __init__(self, model_field, **kwargs):
try: self.model_field = model_field
self.model_field = kwargs.pop('model_field') kwargs['source'] = '*'
except KeyError: super(ModelField, self).__init__(**kwargs)
raise ValueError("ModelField requires 'model_field' kwarg")
self.min_length = kwargs.pop('min_length',
getattr(self.model_field, 'min_length', None))
self.max_length = kwargs.pop('max_length',
getattr(self.model_field, 'max_length', None))
self.min_value = kwargs.pop('min_value',
getattr(self.model_field, 'min_value', None))
self.max_value = kwargs.pop('max_value',
getattr(self.model_field, 'max_value', None))
super(ModelField, self).__init__(*args, **kwargs)
if self.min_length is not None:
self.validators.append(validators.MinLengthValidator(self.min_length))
if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length))
if self.min_value is not None:
self.validators.append(validators.MinValueValidator(self.min_value))
if self.max_value is not None:
self.validators.append(validators.MaxValueValidator(self.max_value))
def get_attribute(self, instance):
return get_attribute(instance, self.source_attrs[:-1])
def to_native(self, data): def to_native(self, data):
rel = getattr(self.model_field, 'rel', None) rel = getattr(self.model_field, 'rel', None)

View File

@ -10,15 +10,15 @@ python primitives.
2. The process of marshalling between python primitives and request and 2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers. response content is handled by parsers and renderers.
""" """
from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.utils import six from django.utils import six
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html from rest_framework.utils import html, modelinfo, representation
import copy import copy
import inspect
# Note: We do the following so that users of the framework can use this style: # Note: We do the following so that users of the framework can use this style:
# #
@ -146,12 +146,10 @@ class SerializerMetaclass(type):
class Serializer(BaseSerializer): class Serializer(BaseSerializer):
def __new__(cls, *args, **kwargs): def __new__(cls, *args, **kwargs):
many = kwargs.pop('many', False) if kwargs.pop('many', False):
if many: kwargs['child'] = cls()
class DynamicListSerializer(ListSerializer): return ListSerializer(*args, **kwargs)
child = cls() return super(Serializer, cls).__new__(cls, *args, **kwargs)
return DynamicListSerializer(*args, **kwargs)
return super(Serializer, cls).__new__(cls)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {}) self.context = kwargs.pop('context', {})
@ -248,6 +246,9 @@ class Serializer(BaseSerializer):
error = errors.get(field.field_name) error = errors.get(field.field_name)
yield FieldResult(field, value, error) yield FieldResult(field, value, error)
def __repr__(self):
return representation.serializer_repr(self, indent=1)
class ListSerializer(BaseSerializer): class ListSerializer(BaseSerializer):
child = None child = None
@ -299,26 +300,8 @@ class ListSerializer(BaseSerializer):
self.instance = self.create(self.validated_data) self.instance = self.create(self.validated_data)
return self.instance return self.instance
def __repr__(self):
def _resolve_model(obj): return representation.list_repr(self, indent=1)
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situtations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format:
'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
else:
raise ValueError("{0} is not a Django model".format(obj))
class ModelSerializerOptions(object): class ModelSerializerOptions(object):
@ -334,24 +317,25 @@ class ModelSerializerOptions(object):
class ModelSerializer(Serializer): class ModelSerializer(Serializer):
field_mapping = { field_mapping = {
models.AutoField: IntegerField, models.AutoField: IntegerField,
# models.FloatField: FloatField, models.BigIntegerField: IntegerField,
models.IntegerField: IntegerField,
models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
models.DateTimeField: DateTimeField,
models.DateField: DateField,
models.TimeField: TimeField,
# models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.CharField: CharField,
models.URLField: URLField,
# models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BooleanField: BooleanField, models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField, models.CharField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.DateField: DateField,
models.DateTimeField: DateTimeField,
models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.FileField: FileField, models.FileField: FileField,
models.FloatField: FloatField,
models.IntegerField: IntegerField,
models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
models.SlugField: SlugField,
models.SmallIntegerField: IntegerField,
models.TextField: CharField,
models.TimeField: TimeField,
models.URLField: URLField,
# models.ImageField: ImageField, # models.ImageField: ImageField,
} }
@ -392,85 +376,31 @@ class ModelSerializer(Serializer):
""" """
Return all the fields that should be serialized for the model. Return all the fields that should be serialized for the model.
""" """
cls = self.opts.model info = modelinfo.get_field_info(self.opts.model)
opts = cls._meta.concrete_model._meta
ret = OrderedDict() ret = OrderedDict()
nested = bool(self.opts.depth)
# Deal with adding the primary key field serializer_pk_field = self.get_pk_field(info.pk)
pk_field = opts.pk
while pk_field.rel and pk_field.rel.parent_link:
# If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
serializer_pk_field = self.get_pk_field(pk_field)
if serializer_pk_field: if serializer_pk_field:
ret[pk_field.name] = serializer_pk_field ret[info.pk.name] = serializer_pk_field
# Deal with forward relationships # Regular fields
forward_rels = [field for field in opts.fields if field.serialize] for field_name, field in info.fields.items():
forward_rels += [field for field in opts.many_to_many if field.serialize] ret[field_name] = self.get_field(field)
for model_field in forward_rels: # Forward relations
has_through_model = False for field_name, relation_info in info.forward_relations.items():
if self.opts.depth:
if model_field.rel: ret[field_name] = self.get_nested_field(*relation_info)
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
related_model = _resolve_model(model_field.rel.to)
if to_many and not model_field.rel.through._meta.auto_created:
has_through_model = True
if model_field.rel and nested:
field = self.get_nested_field(model_field, related_model, to_many)
elif model_field.rel:
field = self.get_related_field(model_field, related_model, to_many)
else: else:
field = self.get_field(model_field) ret[field_name] = self.get_related_field(*relation_info)
if field: # Reverse relations
if has_through_model: for accessor_name, relation_info in info.reverse_relations.items():
field.read_only = True if accessor_name in self.opts.fields:
if self.opts.depth:
ret[model_field.name] = field ret[field_name] = self.get_nested_field(*relation_info)
# Deal with reverse relationships
if not self.opts.fields:
reverse_rels = []
else: else:
# Reverse relationships are only included if they are explicitly ret[field_name] = self.get_related_field(*relation_info)
# present in the `fields` option on the serializer
reverse_rels = opts.get_all_related_objects()
reverse_rels += opts.get_all_related_many_to_many_objects()
for relation in reverse_rels:
accessor_name = relation.get_accessor_name()
if not self.opts.fields or accessor_name not in self.opts.fields:
continue
related_model = relation.model
to_many = relation.field.rel.multiple
has_through_model = False
is_m2m = isinstance(relation.field,
models.fields.related.ManyToManyField)
if (
is_m2m and
hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created
):
has_through_model = True
if nested:
field = self.get_nested_field(None, related_model, to_many)
else:
field = self.get_related_field(None, related_model, to_many)
if field:
if has_through_model:
field.read_only = True
ret[accessor_name] = field
return ret return ret
@ -480,7 +410,7 @@ class ModelSerializer(Serializer):
""" """
return self.get_field(model_field) return self.get_field(model_field)
def get_nested_field(self, model_field, related_model, to_many): def get_nested_field(self, model_field, related_model, to_many, has_through_model):
""" """
Creates a default instance of a nested relational field. Creates a default instance of a nested relational field.
@ -491,59 +421,148 @@ class ModelSerializer(Serializer):
model = related_model model = related_model
depth = self.opts.depth - 1 depth = self.opts.depth - 1
return NestedModelSerializer(many=to_many) kwargs = {'read_only': True}
if to_many:
kwargs['many'] = True
return NestedModelSerializer(**kwargs)
def get_related_field(self, model_field, related_model, to_many): def get_related_field(self, model_field, related_model, to_many, has_through_model):
""" """
Creates a default instance of a flat relational field. Creates a default instance of a flat relational field.
Note that model_field will be `None` for reverse relationships. Note that model_field will be `None` for reverse relationships.
""" """
# TODO: filter queryset using: kwargs = {
# .using(db).complex_filter(self.rel.limit_choices_to) 'queryset': related_model._default_manager,
}
kwargs = {} if to_many:
# 'queryset': related_model._default_manager, kwargs['many'] = True
# 'many': to_many
# } if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field: if model_field:
kwargs['required'] = not(model_field.null or model_field.blank) if model_field.null or model_field.blank:
kwargs['required'] = False
# if model_field.help_text is not None: # if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None: if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name kwargs['label'] = model_field.verbose_name
if not model_field.editable: if not model_field.editable:
kwargs['read_only'] = True kwargs['read_only'] = True
if model_field.verbose_name is not None: kwargs.pop('queryset', None)
kwargs['label'] = model_field.verbose_name
return IntegerField(**kwargs) return PrimaryKeyRelatedField(**kwargs)
# TODO: return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field): def get_field(self, model_field):
""" """
Creates a default instance of a basic non-relational field. Creates a default instance of a basic non-relational field.
""" """
kwargs = {} kwargs = {}
validator_kwarg = model_field.validators
if model_field.null or model_field.blank: if model_field.null or model_field.blank:
kwargs['required'] = False kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
if model_field.has_default():
kwargs['default'] = model_field.get_default()
if issubclass(model_field.__class__, models.TextField):
kwargs['widget'] = widgets.Textarea
if model_field.verbose_name is not None: if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name kwargs['label'] = model_field.verbose_name
if model_field.validators is not None: if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['validators'] = model_field.validators kwargs['read_only'] = True
# Read only implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
# Having a default implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None:
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = getattr(model_field, 'min_length', None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None:
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None:
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if validator_kwarg:
kwargs['validators'] = validator_kwarg
# if issubclass(model_field.__class__, models.TextField):
# kwargs['widget'] = widgets.Textarea
# if model_field.help_text is not None: # if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text # kwargs['help_text'] = model_field.help_text
@ -555,31 +574,10 @@ class ModelSerializer(Serializer):
kwargs['empty'] = None kwargs['empty'] = None
return ChoiceField(**kwargs) return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
if issubclass(model_field.__class__, models.PositiveIntegerField) or \
issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0
if model_field.null and \ if model_field.null and \
issubclass(model_field.__class__, (models.CharField, models.TextField)): issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True kwargs['allow_none'] = True
# attribute_dict = {
# models.CharField: ['max_length'],
# models.CommaSeparatedIntegerField: ['max_length'],
# models.DecimalField: ['max_digits', 'decimal_places'],
# models.EmailField: ['max_length'],
# models.FileField: ['max_length'],
# models.ImageField: ['max_length'],
# models.SlugField: ['max_length'],
# models.URLField: ['max_length'],
# }
# if model_field.__class__ in attribute_dict:
# attributes = attribute_dict[model_field.__class__]
# for attribute in attributes:
# kwargs.update({attribute: getattr(model_field, attribute)})
try: try:
return self.field_mapping[model_field.__class__](**kwargs) return self.field_mapping[model_field.__class__](**kwargs)
except KeyError: except KeyError:
@ -594,28 +592,21 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
super(HyperlinkedModelSerializerOptions, self).__init__(meta) super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None) self.view_name = getattr(meta, 'view_name', None)
self.lookup_field = getattr(meta, 'lookup_field', None) self.lookup_field = getattr(meta, 'lookup_field', None)
self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)
class HyperlinkedModelSerializer(ModelSerializer): class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions _options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
_hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self): def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields() fields = super(HyperlinkedModelSerializer, self).get_default_fields()
if self.opts.view_name is None: if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model) self.opts.view_name = self.get_default_view_name(self.opts.model)
if self.opts.url_field_name not in fields: url_field_name = api_settings.URL_FIELD_NAME
url_field = self._hyperlink_identify_field_class( if url_field_name not in fields:
view_name=self.opts.view_name,
lookup_field=self.opts.lookup_field
)
ret = fields.__class__() ret = fields.__class__()
ret[self.opts.url_field_name] = url_field ret[url_field_name] = self.get_url_field()
ret.update(fields) ret.update(fields)
fields = ret fields = ret
@ -625,39 +616,48 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.fields and model_field.name in self.opts.fields: if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field) return self.get_field(model_field)
def get_related_field(self, model_field, related_model, to_many): def get_url_field(self):
kwargs = {
'view_name': self.get_default_view_name(self.opts.model)
}
if self.opts.lookup_field:
kwargs['lookup_field'] = self.opts.lookup_field
return HyperlinkedIdentityField(**kwargs)
def get_related_field(self, model_field, related_model, to_many, has_through_model):
""" """
Creates a default instance of a flat relational field. Creates a default instance of a flat relational field.
""" """
# TODO: filter queryset using: kwargs = {
# .using(db).complex_filter(self.rel.limit_choices_to) 'queryset': related_model._default_manager,
# kwargs = { 'view_name': self.get_default_view_name(related_model),
# 'queryset': related_model._default_manager, }
# 'view_name': self._get_default_view_name(related_model),
# 'many': to_many if to_many:
# } kwargs['many'] = True
kwargs = {}
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field: if model_field:
kwargs['required'] = not(model_field.null or model_field.blank) if model_field.null or model_field.blank:
kwargs['required'] = False
# if model_field.help_text is not None: # if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text # kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None: if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
return IntegerField(**kwargs) return HyperlinkedRelatedField(**kwargs)
# if self.opts.lookup_field:
# kwargs['lookup_field'] = self.opts.lookup_field
# return self._hyperlink_field_class(**kwargs) def get_default_view_name(self, model):
def _get_default_view_name(self, model):
""" """
Return the view name to use if 'view_name' is not specified in 'Meta' Return the view name to use for related models.
""" """
model_meta = model._meta return '%(model_name)s-detail' % {
format_kwargs = { 'app_label': model._meta.app_label,
'app_label': model_meta.app_label, 'model_name': model._meta.object_name.lower()
'model_name': model_meta.object_name.lower()
} }
return self._default_view_name % format_kwargs

View File

@ -0,0 +1,47 @@
"""
Helper functions that convert strftime formats into more readable representations.
"""
from rest_framework import ISO_8601
def datetime_formats(formats):
format = ', '.join(formats).replace(
ISO_8601,
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
)
return humanize_strptime(format)
def date_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
return humanize_strptime(format)
def time_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
return humanize_strptime(format)
def humanize_strptime(format_string):
# Note that we're missing some of the locale specific mappings that
# don't really make sense.
mapping = {
"%Y": "YYYY",
"%y": "YY",
"%m": "MM",
"%b": "[Jan-Dec]",
"%B": "[January-December]",
"%d": "DD",
"%H": "hh",
"%I": "hh", # Requires '%p' to differentiate from '%H'.
"%M": "mm",
"%S": "ss",
"%f": "uuuuuu",
"%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]",
"%p": "[AM|PM]",
"%z": "[+HHMM|-HHMM]"
}
for key, val in mapping.items():
format_string = format_string.replace(key, val)
return format_string

View File

@ -0,0 +1,97 @@
"""
Helper functions for returning the field information that is associated
with a model class.
"""
from collections import namedtuple, OrderedDict
from django.db import models
from django.utils import six
import inspect
FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
def _resolve_model(obj):
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situtations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format:
'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
raise ValueError("{0} is not a Django model".format(obj))
def get_field_info(model):
"""
Given a model class, returns a `FieldInfo` instance containing metadata
about the various field types on the model.
"""
opts = model._meta.concrete_model._meta
# Deal with the primary key.
pk = opts.pk
while pk.rel and pk.rel.parent_link:
# If model is a child via multitable inheritance, use parent's pk.
pk = pk.rel.to._meta.pk
# Deal with regular fields.
fields = OrderedDict()
for field in [field for field in opts.fields if field.serialize and not field.rel]:
fields[field.name] = field
# Deal with forward relationships.
forward_relations = OrderedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
field=field,
related=_resolve_model(field.rel.to),
to_many=False,
has_through_model=False
)
# Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
field=field,
related=_resolve_model(field.rel.to),
to_many=True,
has_through_model=(
not field.rel.through._meta.auto_created
)
)
# Deal with reverse relationships.
reverse_relations = OrderedDict()
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
field=None,
related=relation.model,
to_many=relation.field.rel.multiple,
has_through_model=False
)
# Deal with reverse many-to-many relationships.
for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
field=None,
related=relation.model,
to_many=True,
has_through_model=(
hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created
)
)
return FieldInfo(pk, fields, forward_relations, reverse_relations)

View File

@ -0,0 +1,72 @@
"""
Helper functions for creating user-friendly representations
of serializer classes and serializer fields.
"""
import re
def smart_repr(value):
value = repr(value)
# Representations like u'help text'
# should simply be presented as 'help text'
if value.startswith("u'") and value.endswith("'"):
return value[1:]
# Representations like
# <django.core.validators.RegexValidator object at 0x1047af050>
# Should be presented as
# <django.core.validators.RegexValidator object>
value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value)
return value
def field_repr(field, force_many=False):
kwargs = field._kwargs
if force_many:
kwargs = kwargs.copy()
kwargs['many'] = True
kwargs.pop('child', None)
arg_string = ', '.join([smart_repr(val) for val in field._args])
kwarg_string = ', '.join([
'%s=%s' % (key, smart_repr(val))
for key, val in sorted(kwargs.items())
])
if arg_string and kwarg_string:
arg_string += ', '
if force_many:
class_name = force_many.__class__.__name__
else:
class_name = field.__class__.__name__
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
def serializer_repr(serializer, indent, force_many=None):
ret = field_repr(serializer, force_many) + ':'
indent_str = ' ' * indent
if force_many:
fields = force_many.fields
else:
fields = serializer.fields
for field_name, field in fields.items():
ret += '\n' + indent_str + field_name + ' = '
if hasattr(field, 'fields'):
ret += serializer_repr(field, indent + 1)
elif hasattr(field, 'child'):
ret += list_repr(field, indent + 1)
else:
ret += field_repr(field)
return ret
def list_repr(serializer, indent):
child = serializer.child
if hasattr(child, 'fields'):
return serializer_repr(serializer, indent, force_many=child)
return field_repr(serializer)

View File

@ -0,0 +1,160 @@
"""
The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially
shortcuts for automatically creating serializers based on a given model class.
These tests deal with ensuring that we correctly map the model fields onto
an appropriate set of serializer fields for each case.
"""
from django.db import models
from django.test import TestCase
from rest_framework import serializers
# Models for testing regular field mapping
class RegularFieldsModel(models.Model):
auto_field = models.AutoField(primary_key=True)
big_integer_field = models.BigIntegerField()
boolean_field = models.BooleanField()
char_field = models.CharField(max_length=100)
comma_seperated_integer_field = models.CommaSeparatedIntegerField(max_length=100)
date_field = models.DateField()
datetime_field = models.DateTimeField()
decimal_field = models.DecimalField(max_digits=3, decimal_places=1)
email_field = models.EmailField(max_length=100)
float_field = models.FloatField()
integer_field = models.IntegerField()
null_boolean_field = models.NullBooleanField()
positive_integer_field = models.PositiveIntegerField()
positive_small_integer_field = models.PositiveSmallIntegerField()
slug_field = models.SlugField(max_length=100)
small_integer_field = models.SmallIntegerField()
text_field = models.TextField()
time_field = models.TimeField()
url_field = models.URLField(max_length=100)
REGULAR_FIELDS_REPR = """
TestSerializer():
auto_field = IntegerField(label='auto field', read_only=True)
big_integer_field = IntegerField(label='big integer field')
boolean_field = BooleanField(default=False, label='boolean field')
char_field = CharField(label='char field', max_length=100)
comma_seperated_integer_field = CharField(label='comma seperated integer field', max_length=100, validators=[<django.core.validators.RegexValidator object>])
date_field = DateField(label='date field')
datetime_field = DateTimeField(label='datetime field')
decimal_field = DecimalField(decimal_places=1, label='decimal field', max_digits=3)
email_field = EmailField(label='email field', max_length=100)
float_field = FloatField(label='float field')
integer_field = IntegerField(label='integer field')
null_boolean_field = BooleanField(label='null boolean field', required=False)
positive_integer_field = IntegerField(label='positive integer field')
positive_small_integer_field = IntegerField(label='positive small integer field')
slug_field = SlugField(label='slug field', max_length=100)
small_integer_field = IntegerField(label='small integer field')
text_field = CharField(label='text field')
time_field = TimeField(label='time field')
url_field = URLField(label='url field', max_length=100)
""".strip()
# Model for testing relational field mapping
class ForeignKeyTarget(models.Model):
char_field = models.CharField(max_length=100)
class ManyToManyTarget(models.Model):
char_field = models.CharField(max_length=100)
class OneToOneTarget(models.Model):
char_field = models.CharField(max_length=100)
class RelationalModel(models.Model):
foreign_key = models.ForeignKey(ForeignKeyTarget)
many_to_many = models.ManyToManyField(ManyToManyTarget)
one_to_one = models.OneToOneField(OneToOneTarget)
RELATIONAL_FLAT_REPR = """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = PrimaryKeyRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>)
one_to_one = PrimaryKeyRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>)
many_to_many = PrimaryKeyRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>)
""".strip()
RELATIONAL_NESTED_REPR = """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = NestedModelSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
one_to_one = NestedModelSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
""".strip()
HYPERLINKED_FLAT_REPR = """
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = HyperlinkedRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>, view_name='foreignkeytarget-detail')
one_to_one = HyperlinkedRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>, view_name='onetoonetarget-detail')
many_to_many = HyperlinkedRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>, view_name='manytomanytarget-detail')
""".strip()
HYPERLINKED_NESTED_REPR = """
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = NestedModelSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
one_to_one = NestedModelSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
""".strip()
class TestSerializerMappings(TestCase):
def test_regular_fields(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
self.assertEqual(repr(TestSerializer()), REGULAR_FIELDS_REPR)
def test_flat_relational_fields(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
self.assertEqual(repr(TestSerializer()), RELATIONAL_FLAT_REPR)
def test_nested_relational_fields(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
depth = 1
self.assertEqual(repr(TestSerializer()), RELATIONAL_NESTED_REPR)
def test_flat_hyperlinked_fields(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
self.assertEqual(repr(TestSerializer()), HYPERLINKED_FLAT_REPR)
def test_nested_hyperlinked_fields(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
depth = 1
self.assertEqual(repr(TestSerializer()), HYPERLINKED_NESTED_REPR)

View File

@ -1,6 +1,6 @@
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from rest_framework.serializers import _resolve_model from rest_framework.utils.modelinfo import _resolve_model
from tests.models import BasicModel from tests.models import BasicModel

View File

@ -22,18 +22,18 @@
# https://github.com/tomchristie/django-rest-framework/issues/446 # https://github.com/tomchristie/django-rest-framework/issues/446
# """ # """
# field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all()) # field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
# self.assertRaises(serializers.ValidationError, field.from_native, '') # self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.from_native, []) # self.assertRaises(serializers.ValidationError, field.to_primative, [])
# def test_hyperlinked_related_field_with_empty_string(self): # def test_hyperlinked_related_field_with_empty_string(self):
# field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='') # field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
# self.assertRaises(serializers.ValidationError, field.from_native, '') # self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.from_native, []) # self.assertRaises(serializers.ValidationError, field.to_primative, [])
# def test_slug_related_field_with_empty_string(self): # def test_slug_related_field_with_empty_string(self):
# field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk') # field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
# self.assertRaises(serializers.ValidationError, field.from_native, '') # self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.from_native, []) # self.assertRaises(serializers.ValidationError, field.to_primative, [])
# class TestManyRelatedMixin(TestCase): # class TestManyRelatedMixin(TestCase):

View File

@ -6,7 +6,7 @@
# def test_empty_serializer(self): # def test_empty_serializer(self):
# class FooBarSerializer(serializers.Serializer): # class FooBarSerializer(serializers.Serializer):
# foo = serializers.IntegerField() # foo = serializers.IntegerField()
# bar = serializers.SerializerMethodField('get_bar') # bar = serializers.MethodField()
# def get_bar(self, obj): # def get_bar(self, obj):
# return 'bar' # return 'bar'