Fleshing out serializer fields

This commit is contained in:
Tom Christie 2014-09-09 17:46:28 +01:00
parent 21980b800d
commit b1c07670ca
9 changed files with 1053 additions and 336 deletions

View File

@ -1,8 +1,18 @@
from django.conf import settings
from django.core import validators
from django.core.exceptions import ValidationError
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type
from rest_framework.utils import html
from django.utils.translation import ugettext_lazy as _
from rest_framework import ISO_8601
from rest_framework.compat import smart_text
from rest_framework.settings import api_settings
from rest_framework.utils import html, representation, humanize_datetime
import datetime
import decimal
import inspect
import warnings
class empty:
@ -71,22 +81,22 @@ class SkipField(Exception):
pass
NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `error_messages` dictionary.'
)
class Field(object):
_creation_counter = 0
MESSAGES = {
'required': 'This field is required.'
default_error_messages = {
'required': _('This field is required.')
}
_NOT_READ_ONLY_WRITE_ONLY = 'May not set both `read_only` and `write_only`'
_NOT_READ_ONLY_REQUIRED = 'May not set both `read_only` and `required`'
_NOT_READ_ONLY_DEFAULT = 'May not set both `read_only` and `default`'
_NOT_REQUIRED_DEFAULT = 'May not set both `required` and `default`'
_MISSING_ERROR_MESSAGE = (
'ValidationError raised by `{class_name}`, but error key `{key}` does '
'not exist in the `MESSAGES` dictionary.'
)
default_validators = []
def __init__(self, read_only=False, write_only=False,
@ -100,10 +110,10 @@ class Field(object):
required = default is empty and not read_only
# Some combinations of keyword arguments do not make sense.
assert not (read_only and write_only), self._NOT_READ_ONLY_WRITE_ONLY
assert not (read_only and required), self._NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), self._NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), self._NOT_REQUIRED_DEFAULT
assert not (read_only and write_only), NOT_READ_ONLY_WRITE_ONLY
assert not (read_only and required), NOT_READ_ONLY_REQUIRED
assert not (read_only and default is not empty), NOT_READ_ONLY_DEFAULT
assert not (required and default is not empty), NOT_REQUIRED_DEFAULT
self.read_only = read_only
self.write_only = write_only
@ -113,7 +123,14 @@ class Field(object):
self.initial = initial
self.label = label
self.style = {} if style is None else style
self.validators = self.default_validators + validators
self.validators = validators or self.default_validators[:]
# Collect default error message from self and parent classes
messages = {}
for cls in reversed(self.__class__.__mro__):
messages.update(getattr(cls, 'default_error_messages', {}))
messages.update(error_messages or {})
self.error_messages = messages
def bind(self, field_name, parent, root):
"""
@ -186,12 +203,14 @@ class Field(object):
self.fail('required')
return self.get_default()
self.run_validators(data)
return self.to_native(data)
value = self.to_native(data)
self.run_validators(value)
return value
def run_validators(self, value):
if value in validators.EMPTY_VALUES:
return
errors = []
for validator in self.validators:
try:
@ -218,33 +237,32 @@ class Field(object):
A helper method that simply raises a validation error.
"""
try:
raise ValidationError(self.MESSAGES[key].format(**kwargs))
msg = self.error_messages[key]
except KeyError:
class_name = self.__class__.__name__
msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg)
raise ValidationError(msg.format(**kwargs))
def __new__(cls, *args, **kwargs):
"""
When a field is instantiated, we store the arguments that were used,
so that we can present a helpful representation of the object.
"""
instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def __repr__(self):
arg_string = ', '.join([repr(val) for val in self._args])
kwarg_string = ', '.join([
'%s=%s' % (key, repr(val)) for key, val in self._kwargs.items()
])
if arg_string and kwarg_string:
arg_string += ', '
class_name = self.__class__.__name__
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
return representation.field_repr(self)
# Boolean types...
class BooleanField(Field):
MESSAGES = {
'required': 'This field is required.',
'invalid_value': '`{input}` is not a valid boolean.'
default_error_messages = {
'invalid': _('`{input}` is not a valid boolean.')
}
TRUE_VALUES = {'t', 'T', 'true', 'True', 'TRUE', '1', 1, True}
FALSE_VALUES = {'f', 'F', 'false', 'False', 'FALSE', '0', 0, 0.0, False}
@ -261,13 +279,23 @@ class BooleanField(Field):
return True
elif data in self.FALSE_VALUES:
return False
self.fail('invalid_value', input=data)
self.fail('invalid', input=data)
def to_primative(self, value):
if value is None:
return None
if value in self.TRUE_VALUES:
return True
elif value in self.FALSE_VALUES:
return False
return bool(value)
# String types...
class CharField(Field):
MESSAGES = {
'required': 'This field is required.',
'blank': 'This field may not be blank.'
default_error_messages = {
'blank': _('This field may not be blank.')
}
def __init__(self, **kwargs):
@ -281,19 +309,364 @@ class CharField(Field):
self.fail('blank')
return str(data)
def to_primative(self, value):
if value is None:
return None
return str(value)
class ChoiceField(Field):
MESSAGES = {
'required': 'This field is required.',
'invalid_choice': '`{input}` is not a valid choice.'
class EmailField(CharField):
default_error_messages = {
'invalid': _('Enter a valid email address.')
}
default_validators = [validators.validate_email]
def to_native(self, data):
ret = super(EmailField, self).to_native(data)
if ret is None:
return None
return ret.strip()
def to_primative(self, value):
ret = super(EmailField, self).to_primative(value)
if ret is None:
return None
return ret.strip()
class RegexField(CharField):
def __init__(self, regex, **kwargs):
kwargs['validators'] = (
[validators.RegexValidator(regex)] +
kwargs.get('validators', [])
)
super(RegexField, self).__init__(**kwargs)
class SlugField(CharField):
default_error_messages = {
'invalid': _("Enter a valid 'slug' consisting of letters, numbers, underscores or hyphens.")
}
default_validators = [validators.validate_slug]
class URLField(CharField):
default_error_messages = {
'invalid': _("Enter a valid URL.")
}
default_validators = [validators.URLValidator()]
# Number types...
class IntegerField(Field):
default_error_messages = {
'invalid': _('A valid integer is required.')
}
coerce_to_type = str
def __init__(self, **kwargs):
choices = kwargs.pop('choices')
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(IntegerField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
print self.__class__.__name__, self.validators
assert choices, '`choices` argument is required and may not be empty'
def to_native(self, data):
try:
data = int(str(data))
except (ValueError, TypeError):
self.fail('invalid')
return data
def to_primative(self, value):
if value is None:
return None
return int(value)
class FloatField(Field):
default_error_messages = {
'invalid': _("'%s' value must be a float."),
}
def __init__(self, **kwargs):
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(FloatField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def to_primative(self, value):
if value is None:
return None
try:
return float(value)
except (TypeError, ValueError):
self.fail('invalid', value=value)
def to_native(self, value):
if value is None:
return None
return float(value)
class DecimalField(Field):
default_error_messages = {
'invalid': _('Enter a number.'),
'max_value': _('Ensure this value is less than or equal to {max_value}.'),
'min_value': _('Ensure this value is greater than or equal to {min_value}.'),
'max_digits': _('Ensure that there are no more than {max_digits} digits in total.'),
'max_decimal_places': _('Ensure that there are no more than {max_decimal_places} decimal places.'),
'max_whole_digits': _('Ensure that there are no more than {max_whole_digits} digits before the decimal point.')
}
def __init__(self, max_value=None, min_value=None, max_digits=None, decimal_places=None, **kwargs):
self.max_value, self.min_value = max_value, min_value
self.max_digits, self.max_decimal_places = max_digits, decimal_places
super(DecimalField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def from_native(self, value):
"""
Validates that the input is a decimal number. Returns a Decimal
instance. Returns None for empty values. Ensures that there are no more
than max_digits in the number, and no more than decimal_places digits
after the decimal point.
"""
if value in validators.EMPTY_VALUES:
return None
value = smart_text(value).strip()
try:
value = decimal.Decimal(value)
except decimal.DecimalException:
self.fail('invalid')
# Check for NaN. It is the only value that isn't equal to itself,
# so we can use this to identify NaN values.
if value != value:
self.fail('invalid')
# Check for infinity and negative infinity.
if value in (decimal.Decimal('Inf'), decimal.Decimal('-Inf')):
self.fail('invalid')
sign, digittuple, exponent = value.as_tuple()
decimals = abs(exponent)
# digittuple doesn't include any leading zeros.
digits = len(digittuple)
if decimals > digits:
# We have leading zeros up to or past the decimal point. Count
# everything past the decimal point as a digit. We do not count
# 0 before the decimal point as a digit since that would mean
# we would not allow max_digits = decimal_places.
digits = decimals
whole_digits = digits - decimals
if self.max_digits is not None and digits > self.max_digits:
self.fail('max_digits', max_digits=self.max_digits)
if self.decimal_places is not None and decimals > self.decimal_places:
self.fail('max_decimal_places', max_decimal_places=self.max_decimal_places)
if self.max_digits is not None and self.decimal_places is not None and whole_digits > (self.max_digits - self.decimal_places):
self.fail('max_whole_digits', max_while_digits=self.max_digits - self.decimal_places)
return value
# Date & time fields...
class DateField(Field):
default_error_messages = {
'invalid': _("Date has wrong format. Use one of these formats instead: %s"),
}
input_formats = api_settings.DATE_INPUT_FORMATS
format = api_settings.DATE_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(DateField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.datetime):
if timezone and settings.USE_TZ and timezone.is_aware(value):
# Convert aware datetimes to the default time zone
# before casting them to dates (#17742).
default_timezone = timezone.get_default_timezone()
value = timezone.make_naive(value, default_timezone)
return value.date()
if isinstance(value, datetime.date):
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_date(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed.date()
humanized_format = humanize_datetime.date_formats(self.input_formats)
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
def to_primative(self, value):
if value is None or self.format is None:
return value
if isinstance(value, datetime.datetime):
value = value.date()
if self.format.lower() == ISO_8601:
return value.isoformat()
return value.strftime(self.format)
class DateTimeField(Field):
default_error_messages = {
'invalid': _("Datetime has wrong format. Use one of these formats instead: %s"),
}
input_formats = api_settings.DATETIME_INPUT_FORMATS
format = api_settings.DATETIME_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(DateTimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.datetime):
return value
if isinstance(value, datetime.date):
value = datetime.datetime(value.year, value.month, value.day)
if settings.USE_TZ:
# For backwards compatibility, interpret naive datetimes in
# local time. This won't work during DST change, but we can't
# do much about it, so we let the exceptions percolate up the
# call stack.
warnings.warn("DateTimeField received a naive datetime (%s)"
" while time zone support is active." % value,
RuntimeWarning)
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone)
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_datetime(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed
humanized_format = humanize_datetime.datetime_formats(self.input_formats)
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
def to_primative(self, value):
if value is None or self.format is None:
return value
if self.format.lower() == ISO_8601:
ret = value.isoformat()
if ret.endswith('+00:00'):
ret = ret[:-6] + 'Z'
return ret
return value.strftime(self.format)
class TimeField(Field):
default_error_messages = {
'invalid': _("Time has wrong format. Use one of these formats instead: %s"),
}
input_formats = api_settings.TIME_INPUT_FORMATS
format = api_settings.TIME_FORMAT
def __init__(self, input_formats=None, format=None, *args, **kwargs):
self.input_formats = input_formats if input_formats is not None else self.input_formats
self.format = format if format is not None else self.format
super(TimeField, self).__init__(*args, **kwargs)
def from_native(self, value):
if value in validators.EMPTY_VALUES:
return None
if isinstance(value, datetime.time):
return value
for format in self.input_formats:
if format.lower() == ISO_8601:
try:
parsed = parse_time(value)
except (ValueError, TypeError):
pass
else:
if parsed is not None:
return parsed
else:
try:
parsed = datetime.datetime.strptime(value, format)
except (ValueError, TypeError):
pass
else:
return parsed.time()
humanized_format = humanize_datetime.time_formats(self.input_formats)
msg = self.error_messages['invalid'] % humanized_format
raise ValidationError(msg)
def to_primative(self, value):
if value is None or self.format is None:
return value
if isinstance(value, datetime.datetime):
value = value.time()
if self.format.lower() == ISO_8601:
return value.isoformat()
return value.strftime(self.format)
# Choice types...
class ChoiceField(Field):
default_error_messages = {
'invalid_choice': _('`{input}` is not a valid choice.')
}
def __init__(self, choices, **kwargs):
# Allow either single or paired choices style:
# choices = [1, 2, 3]
# choices = [(1, 'First'), (2, 'Second'), (3, 'Third')]
@ -321,12 +694,14 @@ class ChoiceField(Field):
except KeyError:
self.fail('invalid_choice', input=data)
def to_primative(self, value):
return value
class MultipleChoiceField(ChoiceField):
MESSAGES = {
'required': 'This field is required.',
'invalid_choice': '`{input}` is not a valid choice.',
'not_a_list': 'Expected a list of items but got type `{input_type}`'
default_error_messages = {
'invalid_choice': _('`{input}` is not a valid choice.'),
'not_a_list': _('Expected a list of items but got type `{input_type}`')
}
def to_native(self, data):
@ -337,72 +712,42 @@ class MultipleChoiceField(ChoiceField):
for item in data
])
class IntegerField(Field):
MESSAGES = {
'required': 'This field is required.',
'invalid_integer': 'A valid integer is required.'
}
def __init__(self, **kwargs):
max_value = kwargs.pop('max_value', None)
min_value = kwargs.pop('min_value', None)
super(IntegerField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def to_native(self, data):
try:
data = int(str(data))
except (ValueError, TypeError):
self.fail('invalid_integer')
return data
def to_primative(self, value):
if value is None:
return None
return int(value)
return value
class EmailField(CharField):
pass # TODO
class URLField(CharField):
pass # TODO
class RegexField(CharField):
def __init__(self, **kwargs):
self.regex = kwargs.pop('regex')
super(CharField, self).__init__(**kwargs)
class DateField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(DateField, self).__init__(**kwargs)
class TimeField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(TimeField, self).__init__(**kwargs)
class DateTimeField(CharField):
def __init__(self, **kwargs):
self.input_formats = kwargs.pop('input_formats', None)
super(DateTimeField, self).__init__(**kwargs)
# File types...
class FileField(Field):
pass # TODO
class ImageField(Field):
pass # TODO
# Advanced field types...
class ReadOnlyField(Field):
"""
A read-only field that simply returns the field value.
If the field is a method with no parameters, the method will be called
and it's return value used as the representation.
For example, the following would call `get_expiry_date()` on the object:
class ExampleSerializer(self):
expiry_date = ReadOnlyField(source='get_expiry_date')
"""
def __init__(self, **kwargs):
kwargs['read_only'] = True
super(ReadOnlyField, self).__init__(**kwargs)
def to_native(self, data):
raise NotImplemented('.to_native() not supported.')
def to_primative(self, value):
if is_simple_callable(value):
return value()
@ -410,11 +755,28 @@ class ReadOnlyField(Field):
class MethodField(Field):
"""
A read-only field that get its representation from calling a method on the
parent serializer class. The method called will be of the form
"get_{field_name}", and should take a single argument, which is the
object being serialized.
For example:
class ExampleSerializer(self):
extra_info = MethodField()
def get_extra_info(self, obj):
return ... # Calculate some data to return.
"""
def __init__(self, **kwargs):
kwargs['source'] = '*'
kwargs['read_only'] = True
super(MethodField, self).__init__(**kwargs)
def to_native(self, data):
raise NotImplemented('.to_native() not supported.')
def to_primative(self, value):
attr = 'get_{field_name}'.format(field_name=self.field_name)
method = getattr(self.parent, attr)
@ -424,35 +786,14 @@ class MethodField(Field):
class ModelField(Field):
"""
A generic field that can be used against an arbitrary model field.
This is used by `ModelSerializer` when dealing with custom model fields,
that do not have a serializer field to be mapped to.
"""
def __init__(self, *args, **kwargs):
try:
self.model_field = kwargs.pop('model_field')
except KeyError:
raise ValueError("ModelField requires 'model_field' kwarg")
self.min_length = kwargs.pop('min_length',
getattr(self.model_field, 'min_length', None))
self.max_length = kwargs.pop('max_length',
getattr(self.model_field, 'max_length', None))
self.min_value = kwargs.pop('min_value',
getattr(self.model_field, 'min_value', None))
self.max_value = kwargs.pop('max_value',
getattr(self.model_field, 'max_value', None))
super(ModelField, self).__init__(*args, **kwargs)
if self.min_length is not None:
self.validators.append(validators.MinLengthValidator(self.min_length))
if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length))
if self.min_value is not None:
self.validators.append(validators.MinValueValidator(self.min_value))
if self.max_value is not None:
self.validators.append(validators.MaxValueValidator(self.max_value))
def get_attribute(self, instance):
return get_attribute(instance, self.source_attrs[:-1])
def __init__(self, model_field, **kwargs):
self.model_field = model_field
kwargs['source'] = '*'
super(ModelField, self).__init__(**kwargs)
def to_native(self, data):
rel = getattr(self.model_field, 'rel', None)

View File

@ -10,15 +10,15 @@ python primitives.
2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers.
"""
from django.core import validators
from django.core.exceptions import ValidationError
from django.db import models
from django.utils import six
from collections import namedtuple, OrderedDict
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings
from rest_framework.utils import html
from rest_framework.utils import html, modelinfo, representation
import copy
import inspect
# Note: We do the following so that users of the framework can use this style:
#
@ -146,12 +146,10 @@ class SerializerMetaclass(type):
class Serializer(BaseSerializer):
def __new__(cls, *args, **kwargs):
many = kwargs.pop('many', False)
if many:
class DynamicListSerializer(ListSerializer):
child = cls()
return DynamicListSerializer(*args, **kwargs)
return super(Serializer, cls).__new__(cls)
if kwargs.pop('many', False):
kwargs['child'] = cls()
return ListSerializer(*args, **kwargs)
return super(Serializer, cls).__new__(cls, *args, **kwargs)
def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {})
@ -248,6 +246,9 @@ class Serializer(BaseSerializer):
error = errors.get(field.field_name)
yield FieldResult(field, value, error)
def __repr__(self):
return representation.serializer_repr(self, indent=1)
class ListSerializer(BaseSerializer):
child = None
@ -299,26 +300,8 @@ class ListSerializer(BaseSerializer):
self.instance = self.create(self.validated_data)
return self.instance
def _resolve_model(obj):
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situtations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format:
'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
else:
raise ValueError("{0} is not a Django model".format(obj))
def __repr__(self):
return representation.list_repr(self, indent=1)
class ModelSerializerOptions(object):
@ -334,24 +317,25 @@ class ModelSerializerOptions(object):
class ModelSerializer(Serializer):
field_mapping = {
models.AutoField: IntegerField,
# models.FloatField: FloatField,
models.IntegerField: IntegerField,
models.PositiveIntegerField: IntegerField,
models.SmallIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
models.DateTimeField: DateTimeField,
models.DateField: DateField,
models.TimeField: TimeField,
# models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.CharField: CharField,
models.URLField: URLField,
# models.SlugField: SlugField,
models.TextField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.BigIntegerField: IntegerField,
models.BooleanField: BooleanField,
models.NullBooleanField: BooleanField,
models.CharField: CharField,
models.CommaSeparatedIntegerField: CharField,
models.DateField: DateField,
models.DateTimeField: DateTimeField,
models.DecimalField: DecimalField,
models.EmailField: EmailField,
models.FileField: FileField,
models.FloatField: FloatField,
models.IntegerField: IntegerField,
models.NullBooleanField: BooleanField,
models.PositiveIntegerField: IntegerField,
models.PositiveSmallIntegerField: IntegerField,
models.SlugField: SlugField,
models.SmallIntegerField: IntegerField,
models.TextField: CharField,
models.TimeField: TimeField,
models.URLField: URLField,
# models.ImageField: ImageField,
}
@ -392,85 +376,31 @@ class ModelSerializer(Serializer):
"""
Return all the fields that should be serialized for the model.
"""
cls = self.opts.model
opts = cls._meta.concrete_model._meta
info = modelinfo.get_field_info(self.opts.model)
ret = OrderedDict()
nested = bool(self.opts.depth)
# Deal with adding the primary key field
pk_field = opts.pk
while pk_field.rel and pk_field.rel.parent_link:
# If model is a child via multitable inheritance, use parent's pk
pk_field = pk_field.rel.to._meta.pk
serializer_pk_field = self.get_pk_field(pk_field)
serializer_pk_field = self.get_pk_field(info.pk)
if serializer_pk_field:
ret[pk_field.name] = serializer_pk_field
ret[info.pk.name] = serializer_pk_field
# Deal with forward relationships
forward_rels = [field for field in opts.fields if field.serialize]
forward_rels += [field for field in opts.many_to_many if field.serialize]
# Regular fields
for field_name, field in info.fields.items():
ret[field_name] = self.get_field(field)
for model_field in forward_rels:
has_through_model = False
if model_field.rel:
to_many = isinstance(model_field,
models.fields.related.ManyToManyField)
related_model = _resolve_model(model_field.rel.to)
if to_many and not model_field.rel.through._meta.auto_created:
has_through_model = True
if model_field.rel and nested:
field = self.get_nested_field(model_field, related_model, to_many)
elif model_field.rel:
field = self.get_related_field(model_field, related_model, to_many)
# Forward relations
for field_name, relation_info in info.forward_relations.items():
if self.opts.depth:
ret[field_name] = self.get_nested_field(*relation_info)
else:
field = self.get_field(model_field)
ret[field_name] = self.get_related_field(*relation_info)
if field:
if has_through_model:
field.read_only = True
ret[model_field.name] = field
# Deal with reverse relationships
if not self.opts.fields:
reverse_rels = []
else:
# Reverse relationships are only included if they are explicitly
# present in the `fields` option on the serializer
reverse_rels = opts.get_all_related_objects()
reverse_rels += opts.get_all_related_many_to_many_objects()
for relation in reverse_rels:
accessor_name = relation.get_accessor_name()
if not self.opts.fields or accessor_name not in self.opts.fields:
continue
related_model = relation.model
to_many = relation.field.rel.multiple
has_through_model = False
is_m2m = isinstance(relation.field,
models.fields.related.ManyToManyField)
if (
is_m2m and
hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created
):
has_through_model = True
if nested:
field = self.get_nested_field(None, related_model, to_many)
else:
field = self.get_related_field(None, related_model, to_many)
if field:
if has_through_model:
field.read_only = True
ret[accessor_name] = field
# Reverse relations
for accessor_name, relation_info in info.reverse_relations.items():
if accessor_name in self.opts.fields:
if self.opts.depth:
ret[field_name] = self.get_nested_field(*relation_info)
else:
ret[field_name] = self.get_related_field(*relation_info)
return ret
@ -480,7 +410,7 @@ class ModelSerializer(Serializer):
"""
return self.get_field(model_field)
def get_nested_field(self, model_field, related_model, to_many):
def get_nested_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a nested relational field.
@ -491,59 +421,148 @@ class ModelSerializer(Serializer):
model = related_model
depth = self.opts.depth - 1
return NestedModelSerializer(many=to_many)
kwargs = {'read_only': True}
if to_many:
kwargs['many'] = True
return NestedModelSerializer(**kwargs)
def get_related_field(self, model_field, related_model, to_many):
def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
Note that model_field will be `None` for reverse relationships.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
kwargs = {
'queryset': related_model._default_manager,
}
kwargs = {}
# 'queryset': related_model._default_manager,
# 'many': to_many
# }
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
if model_field.null or model_field.blank:
kwargs['required'] = False
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
kwargs.pop('queryset', None)
return IntegerField(**kwargs)
# TODO: return PrimaryKeyRelatedField(**kwargs)
return PrimaryKeyRelatedField(**kwargs)
def get_field(self, model_field):
"""
Creates a default instance of a basic non-relational field.
"""
kwargs = {}
validator_kwarg = model_field.validators
if model_field.null or model_field.blank:
kwargs['required'] = False
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
if model_field.has_default():
kwargs['default'] = model_field.get_default()
if issubclass(model_field.__class__, models.TextField):
kwargs['widget'] = widgets.Textarea
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if model_field.validators is not None:
kwargs['validators'] = model_field.validators
if isinstance(model_field, models.AutoField) or not model_field.editable:
kwargs['read_only'] = True
# Read only implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
if model_field.has_default():
kwargs['default'] = model_field.get_default()
# Having a default implies that the field is not required.
# We have a cleaner repr on the instance if we don't set it.
kwargs.pop('required', None)
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None:
kwargs['max_length'] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = getattr(model_field, 'min_length', None)
if min_length is not None:
kwargs['min_length'] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
if max_value is not None:
kwargs['max_value'] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
if min_value is not None:
kwargs['min_value'] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
# URLField does not need to include the URLValidator argument,
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
# EmailField does not need to include the validate_email argument,
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
if validator is not validators.validate_slug
]
max_digits = getattr(model_field, 'max_digits', None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
if validator_kwarg:
kwargs['validators'] = validator_kwarg
# if issubclass(model_field.__class__, models.TextField):
# kwargs['widget'] = widgets.Textarea
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
@ -555,31 +574,10 @@ class ModelSerializer(Serializer):
kwargs['empty'] = None
return ChoiceField(**kwargs)
# put this below the ChoiceField because min_value isn't a valid initializer
if issubclass(model_field.__class__, models.PositiveIntegerField) or \
issubclass(model_field.__class__, models.PositiveSmallIntegerField):
kwargs['min_value'] = 0
if model_field.null and \
issubclass(model_field.__class__, (models.CharField, models.TextField)):
kwargs['allow_none'] = True
# attribute_dict = {
# models.CharField: ['max_length'],
# models.CommaSeparatedIntegerField: ['max_length'],
# models.DecimalField: ['max_digits', 'decimal_places'],
# models.EmailField: ['max_length'],
# models.FileField: ['max_length'],
# models.ImageField: ['max_length'],
# models.SlugField: ['max_length'],
# models.URLField: ['max_length'],
# }
# if model_field.__class__ in attribute_dict:
# attributes = attribute_dict[model_field.__class__]
# for attribute in attributes:
# kwargs.update({attribute: getattr(model_field, attribute)})
try:
return self.field_mapping[model_field.__class__](**kwargs)
except KeyError:
@ -594,28 +592,21 @@ class HyperlinkedModelSerializerOptions(ModelSerializerOptions):
super(HyperlinkedModelSerializerOptions, self).__init__(meta)
self.view_name = getattr(meta, 'view_name', None)
self.lookup_field = getattr(meta, 'lookup_field', None)
self.url_field_name = getattr(meta, 'url_field_name', api_settings.URL_FIELD_NAME)
class HyperlinkedModelSerializer(ModelSerializer):
_options_class = HyperlinkedModelSerializerOptions
_default_view_name = '%(model_name)s-detail'
_hyperlink_field_class = HyperlinkedRelatedField
_hyperlink_identify_field_class = HyperlinkedIdentityField
def get_default_fields(self):
fields = super(HyperlinkedModelSerializer, self).get_default_fields()
if self.opts.view_name is None:
self.opts.view_name = self._get_default_view_name(self.opts.model)
self.opts.view_name = self.get_default_view_name(self.opts.model)
if self.opts.url_field_name not in fields:
url_field = self._hyperlink_identify_field_class(
view_name=self.opts.view_name,
lookup_field=self.opts.lookup_field
)
url_field_name = api_settings.URL_FIELD_NAME
if url_field_name not in fields:
ret = fields.__class__()
ret[self.opts.url_field_name] = url_field
ret[url_field_name] = self.get_url_field()
ret.update(fields)
fields = ret
@ -625,39 +616,48 @@ class HyperlinkedModelSerializer(ModelSerializer):
if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field)
def get_related_field(self, model_field, related_model, to_many):
def get_url_field(self):
kwargs = {
'view_name': self.get_default_view_name(self.opts.model)
}
if self.opts.lookup_field:
kwargs['lookup_field'] = self.opts.lookup_field
return HyperlinkedIdentityField(**kwargs)
def get_related_field(self, model_field, related_model, to_many, has_through_model):
"""
Creates a default instance of a flat relational field.
"""
# TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to)
# kwargs = {
# 'queryset': related_model._default_manager,
# 'view_name': self._get_default_view_name(related_model),
# 'many': to_many
# }
kwargs = {}
kwargs = {
'queryset': related_model._default_manager,
'view_name': self.get_default_view_name(related_model),
}
if to_many:
kwargs['many'] = True
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if model_field:
kwargs['required'] = not(model_field.null or model_field.blank)
if model_field.null or model_field.blank:
kwargs['required'] = False
# if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text
if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
return IntegerField(**kwargs)
# if self.opts.lookup_field:
# kwargs['lookup_field'] = self.opts.lookup_field
return HyperlinkedRelatedField(**kwargs)
# return self._hyperlink_field_class(**kwargs)
def _get_default_view_name(self, model):
def get_default_view_name(self, model):
"""
Return the view name to use if 'view_name' is not specified in 'Meta'
Return the view name to use for related models.
"""
model_meta = model._meta
format_kwargs = {
'app_label': model_meta.app_label,
'model_name': model_meta.object_name.lower()
return '%(model_name)s-detail' % {
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
}
return self._default_view_name % format_kwargs

View File

@ -0,0 +1,47 @@
"""
Helper functions that convert strftime formats into more readable representations.
"""
from rest_framework import ISO_8601
def datetime_formats(formats):
format = ', '.join(formats).replace(
ISO_8601,
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
)
return humanize_strptime(format)
def date_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'YYYY[-MM[-DD]]')
return humanize_strptime(format)
def time_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
return humanize_strptime(format)
def humanize_strptime(format_string):
# Note that we're missing some of the locale specific mappings that
# don't really make sense.
mapping = {
"%Y": "YYYY",
"%y": "YY",
"%m": "MM",
"%b": "[Jan-Dec]",
"%B": "[January-December]",
"%d": "DD",
"%H": "hh",
"%I": "hh", # Requires '%p' to differentiate from '%H'.
"%M": "mm",
"%S": "ss",
"%f": "uuuuuu",
"%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]",
"%p": "[AM|PM]",
"%z": "[+HHMM|-HHMM]"
}
for key, val in mapping.items():
format_string = format_string.replace(key, val)
return format_string

View File

@ -0,0 +1,97 @@
"""
Helper functions for returning the field information that is associated
with a model class.
"""
from collections import namedtuple, OrderedDict
from django.db import models
from django.utils import six
import inspect
FieldInfo = namedtuple('FieldResult', ['pk', 'fields', 'forward_relations', 'reverse_relations'])
RelationInfo = namedtuple('RelationInfo', ['field', 'related', 'to_many', 'has_through_model'])
def _resolve_model(obj):
"""
Resolve supplied `obj` to a Django model class.
`obj` must be a Django model class itself, or a string
representation of one. Useful in situtations like GH #1225 where
Django may not have resolved a string-based reference to a model in
another model's foreign key definition.
String representations should have the format:
'appname.ModelName'
"""
if isinstance(obj, six.string_types) and len(obj.split('.')) == 2:
app_name, model_name = obj.split('.')
return models.get_model(app_name, model_name)
elif inspect.isclass(obj) and issubclass(obj, models.Model):
return obj
raise ValueError("{0} is not a Django model".format(obj))
def get_field_info(model):
"""
Given a model class, returns a `FieldInfo` instance containing metadata
about the various field types on the model.
"""
opts = model._meta.concrete_model._meta
# Deal with the primary key.
pk = opts.pk
while pk.rel and pk.rel.parent_link:
# If model is a child via multitable inheritance, use parent's pk.
pk = pk.rel.to._meta.pk
# Deal with regular fields.
fields = OrderedDict()
for field in [field for field in opts.fields if field.serialize and not field.rel]:
fields[field.name] = field
# Deal with forward relationships.
forward_relations = OrderedDict()
for field in [field for field in opts.fields if field.serialize and field.rel]:
forward_relations[field.name] = RelationInfo(
field=field,
related=_resolve_model(field.rel.to),
to_many=False,
has_through_model=False
)
# Deal with forward many-to-many relationships.
for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo(
field=field,
related=_resolve_model(field.rel.to),
to_many=True,
has_through_model=(
not field.rel.through._meta.auto_created
)
)
# Deal with reverse relationships.
reverse_relations = OrderedDict()
for relation in opts.get_all_related_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
field=None,
related=relation.model,
to_many=relation.field.rel.multiple,
has_through_model=False
)
# Deal with reverse many-to-many relationships.
for relation in opts.get_all_related_many_to_many_objects():
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
field=None,
related=relation.model,
to_many=True,
has_through_model=(
hasattr(relation.field.rel, 'through') and
not relation.field.rel.through._meta.auto_created
)
)
return FieldInfo(pk, fields, forward_relations, reverse_relations)

View File

@ -0,0 +1,72 @@
"""
Helper functions for creating user-friendly representations
of serializer classes and serializer fields.
"""
import re
def smart_repr(value):
value = repr(value)
# Representations like u'help text'
# should simply be presented as 'help text'
if value.startswith("u'") and value.endswith("'"):
return value[1:]
# Representations like
# <django.core.validators.RegexValidator object at 0x1047af050>
# Should be presented as
# <django.core.validators.RegexValidator object>
value = re.sub(' at 0x[0-9a-f]{8,10}>', '>', value)
return value
def field_repr(field, force_many=False):
kwargs = field._kwargs
if force_many:
kwargs = kwargs.copy()
kwargs['many'] = True
kwargs.pop('child', None)
arg_string = ', '.join([smart_repr(val) for val in field._args])
kwarg_string = ', '.join([
'%s=%s' % (key, smart_repr(val))
for key, val in sorted(kwargs.items())
])
if arg_string and kwarg_string:
arg_string += ', '
if force_many:
class_name = force_many.__class__.__name__
else:
class_name = field.__class__.__name__
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
def serializer_repr(serializer, indent, force_many=None):
ret = field_repr(serializer, force_many) + ':'
indent_str = ' ' * indent
if force_many:
fields = force_many.fields
else:
fields = serializer.fields
for field_name, field in fields.items():
ret += '\n' + indent_str + field_name + ' = '
if hasattr(field, 'fields'):
ret += serializer_repr(field, indent + 1)
elif hasattr(field, 'child'):
ret += list_repr(field, indent + 1)
else:
ret += field_repr(field)
return ret
def list_repr(serializer, indent):
child = serializer.child
if hasattr(child, 'fields'):
return serializer_repr(serializer, indent, force_many=child)
return field_repr(serializer)

View File

@ -0,0 +1,160 @@
"""
The `ModelSerializer` and `HyperlinkedModelSerializer` classes are essentially
shortcuts for automatically creating serializers based on a given model class.
These tests deal with ensuring that we correctly map the model fields onto
an appropriate set of serializer fields for each case.
"""
from django.db import models
from django.test import TestCase
from rest_framework import serializers
# Models for testing regular field mapping
class RegularFieldsModel(models.Model):
auto_field = models.AutoField(primary_key=True)
big_integer_field = models.BigIntegerField()
boolean_field = models.BooleanField()
char_field = models.CharField(max_length=100)
comma_seperated_integer_field = models.CommaSeparatedIntegerField(max_length=100)
date_field = models.DateField()
datetime_field = models.DateTimeField()
decimal_field = models.DecimalField(max_digits=3, decimal_places=1)
email_field = models.EmailField(max_length=100)
float_field = models.FloatField()
integer_field = models.IntegerField()
null_boolean_field = models.NullBooleanField()
positive_integer_field = models.PositiveIntegerField()
positive_small_integer_field = models.PositiveSmallIntegerField()
slug_field = models.SlugField(max_length=100)
small_integer_field = models.SmallIntegerField()
text_field = models.TextField()
time_field = models.TimeField()
url_field = models.URLField(max_length=100)
REGULAR_FIELDS_REPR = """
TestSerializer():
auto_field = IntegerField(label='auto field', read_only=True)
big_integer_field = IntegerField(label='big integer field')
boolean_field = BooleanField(default=False, label='boolean field')
char_field = CharField(label='char field', max_length=100)
comma_seperated_integer_field = CharField(label='comma seperated integer field', max_length=100, validators=[<django.core.validators.RegexValidator object>])
date_field = DateField(label='date field')
datetime_field = DateTimeField(label='datetime field')
decimal_field = DecimalField(decimal_places=1, label='decimal field', max_digits=3)
email_field = EmailField(label='email field', max_length=100)
float_field = FloatField(label='float field')
integer_field = IntegerField(label='integer field')
null_boolean_field = BooleanField(label='null boolean field', required=False)
positive_integer_field = IntegerField(label='positive integer field')
positive_small_integer_field = IntegerField(label='positive small integer field')
slug_field = SlugField(label='slug field', max_length=100)
small_integer_field = IntegerField(label='small integer field')
text_field = CharField(label='text field')
time_field = TimeField(label='time field')
url_field = URLField(label='url field', max_length=100)
""".strip()
# Model for testing relational field mapping
class ForeignKeyTarget(models.Model):
char_field = models.CharField(max_length=100)
class ManyToManyTarget(models.Model):
char_field = models.CharField(max_length=100)
class OneToOneTarget(models.Model):
char_field = models.CharField(max_length=100)
class RelationalModel(models.Model):
foreign_key = models.ForeignKey(ForeignKeyTarget)
many_to_many = models.ManyToManyField(ManyToManyTarget)
one_to_one = models.OneToOneField(OneToOneTarget)
RELATIONAL_FLAT_REPR = """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = PrimaryKeyRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>)
one_to_one = PrimaryKeyRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>)
many_to_many = PrimaryKeyRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>)
""".strip()
RELATIONAL_NESTED_REPR = """
TestSerializer():
id = IntegerField(label='ID', read_only=True)
foreign_key = NestedModelSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
one_to_one = NestedModelSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
""".strip()
HYPERLINKED_FLAT_REPR = """
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = HyperlinkedRelatedField(label='foreign key', queryset=<django.db.models.manager.Manager object>, view_name='foreignkeytarget-detail')
one_to_one = HyperlinkedRelatedField(label='one to one', queryset=<django.db.models.manager.Manager object>, view_name='onetoonetarget-detail')
many_to_many = HyperlinkedRelatedField(label='many to many', many=True, queryset=<django.db.models.manager.Manager object>, view_name='manytomanytarget-detail')
""".strip()
HYPERLINKED_NESTED_REPR = """
TestSerializer():
url = HyperlinkedIdentityField(view_name='relationalmodel-detail')
foreign_key = NestedModelSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
one_to_one = NestedModelSerializer(read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
many_to_many = NestedModelSerializer(many=True, read_only=True):
id = IntegerField(label='ID', read_only=True)
name = CharField(label='name', max_length=100)
""".strip()
class TestSerializerMappings(TestCase):
def test_regular_fields(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RegularFieldsModel
self.assertEqual(repr(TestSerializer()), REGULAR_FIELDS_REPR)
def test_flat_relational_fields(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
self.assertEqual(repr(TestSerializer()), RELATIONAL_FLAT_REPR)
def test_nested_relational_fields(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
depth = 1
self.assertEqual(repr(TestSerializer()), RELATIONAL_NESTED_REPR)
def test_flat_hyperlinked_fields(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
self.assertEqual(repr(TestSerializer()), HYPERLINKED_FLAT_REPR)
def test_nested_hyperlinked_fields(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
depth = 1
self.assertEqual(repr(TestSerializer()), HYPERLINKED_NESTED_REPR)

View File

@ -1,6 +1,6 @@
from django.test import TestCase
from django.utils import six
from rest_framework.serializers import _resolve_model
from rest_framework.utils.modelinfo import _resolve_model
from tests.models import BasicModel

View File

@ -22,18 +22,18 @@
# https://github.com/tomchristie/django-rest-framework/issues/446
# """
# field = serializers.PrimaryKeyRelatedField(queryset=NullModel.objects.all())
# self.assertRaises(serializers.ValidationError, field.from_native, '')
# self.assertRaises(serializers.ValidationError, field.from_native, [])
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# def test_hyperlinked_related_field_with_empty_string(self):
# field = serializers.HyperlinkedRelatedField(queryset=NullModel.objects.all(), view_name='')
# self.assertRaises(serializers.ValidationError, field.from_native, '')
# self.assertRaises(serializers.ValidationError, field.from_native, [])
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# def test_slug_related_field_with_empty_string(self):
# field = serializers.SlugRelatedField(queryset=NullModel.objects.all(), slug_field='pk')
# self.assertRaises(serializers.ValidationError, field.from_native, '')
# self.assertRaises(serializers.ValidationError, field.from_native, [])
# self.assertRaises(serializers.ValidationError, field.to_primative, '')
# self.assertRaises(serializers.ValidationError, field.to_primative, [])
# class TestManyRelatedMixin(TestCase):

View File

@ -6,7 +6,7 @@
# def test_empty_serializer(self):
# class FooBarSerializer(serializers.Serializer):
# foo = serializers.IntegerField()
# bar = serializers.SerializerMethodField('get_bar')
# bar = serializers.MethodField()
# def get_bar(self, obj):
# return 'bar'