This commit is contained in:
Ross Patterson 2017-10-03 11:16:35 +00:00 committed by GitHub
commit bfe2359f21
6 changed files with 48 additions and 12 deletions

View File

@ -10,11 +10,13 @@ import uuid
from collections import OrderedDict from collections import OrderedDict
from django.conf import settings from django.conf import settings
from django.contrib.contenttypes import fields as ct_fields
from django.core.exceptions import ValidationError as DjangoValidationError from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist
from django.core.validators import ( from django.core.validators import (
EmailValidator, RegexValidator, URLValidator, ip_address_validators EmailValidator, RegexValidator, URLValidator, ip_address_validators
) )
from django.db import models
from django.forms import FilePathField as DjangoFilePathField from django.forms import FilePathField as DjangoFilePathField
from django.forms import ImageField as DjangoImageField from django.forms import ImageField as DjangoImageField
from django.utils import six, timezone from django.utils import six, timezone
@ -85,7 +87,7 @@ else:
return len_args <= len_defaults return len_args <= len_defaults
def get_attribute(instance, attrs): def get_attribute(instance, attrs, exc_on_model_default=False):
""" """
Similar to Python's built in `getattr(instance, attr)`, Similar to Python's built in `getattr(instance, attr)`,
but takes a list of nested attributes, instead of a single attribute. but takes a list of nested attributes, instead of a single attribute.
@ -96,6 +98,33 @@ def get_attribute(instance, attrs):
try: try:
if isinstance(instance, collections.Mapping): if isinstance(instance, collections.Mapping):
instance = instance[attr] instance = instance[attr]
elif exc_on_model_default and isinstance(instance, models.Model):
# Lookup the model field default
try:
field = instance._meta.get_field(attr)
except FieldDoesNotExist:
field = None
else:
if isinstance(field, ct_fields.GenericForeignKey):
# For generic relations, use the foreign key as the
# default
field = instance._meta.get_field(field.fk_field)
elif not hasattr(field, 'get_default') and hasattr(
field, 'target_field'):
# Some relationship fields don't have their own
# `get_default()`
field = field.target_field
value = getattr(instance, attr)
if field is not None:
default = field.get_default()
if default is not empty and value == default:
# Support skipping model fields. They always return
# at least the field default so there's no
# AttributeError unless we force it.
raise AttributeError(
'{0!r} object has no attribute {1!r}'.format(
instance, attr))
instance = value
else: else:
instance = getattr(instance, attr) instance = getattr(instance, attr)
except ObjectDoesNotExist: except ObjectDoesNotExist:
@ -438,7 +467,9 @@ class Field(object):
that should be used for this field. that should be used for this field.
""" """
try: try:
return get_attribute(instance, self.source_attrs) return get_attribute(
instance, self.source_attrs, exc_on_model_default=(
self.default is not empty or not self.required))
except (KeyError, AttributeError) as exc: except (KeyError, AttributeError) as exc:
if self.default is not empty: if self.default is not empty:
return self.get_default() return self.get_default()

View File

@ -176,7 +176,7 @@ class RelatedField(Field):
pass pass
# Standard case, return the object instance. # Standard case, return the object instance.
return get_attribute(instance, self.source_attrs) return super(RelatedField, self).get_attribute(instance)
def get_choices(self, cutoff=None): def get_choices(self, cutoff=None):
queryset = self.get_queryset() queryset = self.get_queryset()

View File

@ -44,7 +44,7 @@ class InheritedModelSerializationTests(TestCase):
""" """
child = ChildModel(name1='parent name', name2='child name') child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child) serializer = DerivedModelSerializer(child)
assert set(serializer.data.keys()) == set(['name1', 'name2', 'id']) assert set(serializer.data.keys()) == set(['name1', 'name2'])
def test_onetoone_primary_key_model_fields_as_expected(self): def test_onetoone_primary_key_model_fields_as_expected(self):
""" """

View File

@ -43,4 +43,4 @@ class InheritedModelSerializationTests(TestCase):
child = ChildModel(name1='parent name', name2='child name') child = ChildModel(name1='parent name', name2='child name')
serializer = DerivedModelSerializer(child) serializer = DerivedModelSerializer(child)
self.assertEqual(set(serializer.data.keys()), self.assertEqual(set(serializer.data.keys()),
set(['name1', 'name2', 'id', 'childassociatedmodel'])) set(['name1', 'name2', 'childassociatedmodel']))

View File

@ -170,7 +170,7 @@ class PKManyToManyTests(TestCase):
serializer = ManyToManySourceSerializer(source) serializer = ManyToManySourceSerializer(source)
expected = {'id': None, 'name': 'source-unsaved', 'targets': []} expected = {'name': 'source-unsaved', 'targets': []}
# no query if source hasn't been created yet # no query if source hasn't been created yet
with self.assertNumQueries(0): with self.assertNumQueries(0):
assert serializer.data == expected assert serializer.data == expected
@ -330,7 +330,7 @@ class PKForeignKeyTests(TestCase):
def test_foreign_key_with_unsaved(self): def test_foreign_key_with_unsaved(self):
source = ForeignKeySource(name='source-unsaved') source = ForeignKeySource(name='source-unsaved')
expected = {'id': None, 'name': 'source-unsaved', 'target': None} expected = {'name': 'source-unsaved', 'target': None}
serializer = ForeignKeySourceSerializer(source) serializer = ForeignKeySourceSerializer(source)

View File

@ -360,12 +360,17 @@ class TestNotRequiredOutput:
""" """
'required=False' should allow an object attribute to be missing in output. 'required=False' should allow an object attribute to be missing in output.
""" """
class ExampleSerializer(serializers.Serializer): class MyModel(models.Model):
omitted = serializers.CharField(required=False) omitted = models.CharField(max_length=10, blank=True)
included = serializers.CharField() included = models.CharField(max_length=10)
class ExampleSerializer(serializers.ModelSerializer):
class Meta:
model = MyModel
exclude = ('id', )
def create(self, validated_data): def create(self, validated_data):
return MockObject(**validated_data) return self.Meta.model(**validated_data)
serializer = ExampleSerializer(data={'included': 'abc'}) serializer = ExampleSerializer(data={'included': 'abc'})
serializer.is_valid() serializer.is_valid()