From 80d3932faca310fe96b94306e3eb2fc720d851e4 Mon Sep 17 00:00:00 2001 From: Ross Patterson Date: Sun, 24 Sep 2017 14:33:12 -0700 Subject: [PATCH] Fix model fields not being omitted on output/serialization --- rest_framework/fields.py | 37 +++++++++++++++++++++-- rest_framework/relations.py | 2 +- tests/test_multitable_inheritance.py | 2 +- tests/test_one_to_one_with_inheritance.py | 2 +- tests/test_relations_pk.py | 4 +-- tests/test_serializer.py | 13 +++++--- 6 files changed, 48 insertions(+), 12 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 072bbf1b9..d4cda4777 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -11,12 +11,14 @@ import uuid from collections import OrderedDict from django.conf import settings +from django.contrib.contenttypes import fields as ct_fields from django.core.exceptions import ValidationError as DjangoValidationError -from django.core.exceptions import ObjectDoesNotExist +from django.core.exceptions import FieldDoesNotExist, ObjectDoesNotExist from django.core.validators import ( EmailValidator, MaxLengthValidator, MaxValueValidator, MinLengthValidator, MinValueValidator, RegexValidator, URLValidator, ip_address_validators ) +from django.db import models from django.forms import FilePathField as DjangoFilePathField from django.forms import ImageField as DjangoImageField from django.utils import six, timezone @@ -85,7 +87,7 @@ else: return len_args <= len_defaults -def get_attribute(instance, attrs): +def get_attribute(instance, attrs, exc_on_model_default=False): """ Similar to Python's built in `getattr(instance, attr)`, but takes a list of nested attributes, instead of a single attribute. @@ -96,6 +98,33 @@ def get_attribute(instance, attrs): try: if isinstance(instance, collections.Mapping): instance = instance[attr] + elif exc_on_model_default and isinstance(instance, models.Model): + # Lookup the model field default + try: + field = instance._meta.get_field(attr) + except FieldDoesNotExist: + field = None + else: + if isinstance(field, ct_fields.GenericForeignKey): + # For generic relations, use the foreign key as the + # default + field = instance._meta.get_field(field.fk_field) + elif not hasattr(field, 'get_default') and hasattr( + field, 'target_field'): + # Some relationship fields don't have their own + # `get_default()` + field = field.target_field + value = getattr(instance, attr) + if field is not None: + default = field.get_default() + if default is not empty and value == default: + # Support skipping model fields. They always return + # at least the field default so there's no + # AttributeError unless we force it. + raise AttributeError( + '{0!r} object has no attribute {1!r}'.format( + instance, attr)) + instance = value else: instance = getattr(instance, attr) except ObjectDoesNotExist: @@ -438,7 +467,9 @@ class Field(object): that should be used for this field. """ try: - return get_attribute(instance, self.source_attrs) + return get_attribute( + instance, self.source_attrs, exc_on_model_default=( + self.default is not empty or not self.required)) except (KeyError, AttributeError) as exc: if self.default is not empty: return self.get_default() diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4d3bdba1d..93e9e9f63 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -176,7 +176,7 @@ class RelatedField(Field): pass # Standard case, return the object instance. - return get_attribute(instance, self.source_attrs) + return super(RelatedField, self).get_attribute(instance) def get_choices(self, cutoff=None): queryset = self.get_queryset() diff --git a/tests/test_multitable_inheritance.py b/tests/test_multitable_inheritance.py index 2aacbc348..ba342e751 100644 --- a/tests/test_multitable_inheritance.py +++ b/tests/test_multitable_inheritance.py @@ -44,7 +44,7 @@ class InheritedModelSerializationTests(TestCase): """ child = ChildModel(name1='parent name', name2='child name') serializer = DerivedModelSerializer(child) - assert set(serializer.data.keys()) == set(['name1', 'name2', 'id']) + assert set(serializer.data.keys()) == set(['name1', 'name2']) def test_onetoone_primary_key_model_fields_as_expected(self): """ diff --git a/tests/test_one_to_one_with_inheritance.py b/tests/test_one_to_one_with_inheritance.py index 9c489c1df..1170b63c1 100644 --- a/tests/test_one_to_one_with_inheritance.py +++ b/tests/test_one_to_one_with_inheritance.py @@ -43,4 +43,4 @@ class InheritedModelSerializationTests(TestCase): child = ChildModel(name1='parent name', name2='child name') serializer = DerivedModelSerializer(child) self.assertEqual(set(serializer.data.keys()), - set(['name1', 'name2', 'id', 'childassociatedmodel'])) + set(['name1', 'name2', 'childassociatedmodel'])) diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index 2eebe1b5c..68af29c45 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -170,7 +170,7 @@ class PKManyToManyTests(TestCase): serializer = ManyToManySourceSerializer(source) - expected = {'id': None, 'name': 'source-unsaved', 'targets': []} + expected = {'name': 'source-unsaved', 'targets': []} # no query if source hasn't been created yet with self.assertNumQueries(0): assert serializer.data == expected @@ -330,7 +330,7 @@ class PKForeignKeyTests(TestCase): def test_foreign_key_with_unsaved(self): source = ForeignKeySource(name='source-unsaved') - expected = {'id': None, 'name': 'source-unsaved', 'target': None} + expected = {'name': 'source-unsaved', 'target': None} serializer = ForeignKeySourceSerializer(source) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index af5206a9f..e67e592d8 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -360,12 +360,17 @@ class TestNotRequiredOutput: """ 'required=False' should allow an object attribute to be missing in output. """ - class ExampleSerializer(serializers.Serializer): - omitted = serializers.CharField(required=False) - included = serializers.CharField() + class MyModel(models.Model): + omitted = models.CharField(max_length=10, blank=True) + included = models.CharField(max_length=10) + + class ExampleSerializer(serializers.ModelSerializer): + class Meta: + model = MyModel + exclude = ('id', ) def create(self, validated_data): - return MockObject(**validated_data) + return self.Meta.model(**validated_data) serializer = ExampleSerializer(data={'included': 'abc'}) serializer.is_valid()