mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 09:57:55 +03:00 
			
		
		
		
	Handle unset fields with 'many=True' (#7574)
* Handle unset fields with 'many=True' The docs note: When serializing fields with dotted notation, it may be necessary to provide a `default` value if any object is not present or is empty during attribute traversal. However, this doesn't work for fields with 'many=True'. When using these, the default is simply ignored. The solution is simple: do in 'ManyRelatedField' what we were already doing for 'Field', namely, catch possible 'AttributeError' and 'KeyError' exceptions and return the default if there is one set. Signed-off-by: Stephen Finucane <stephen@that.guru> Closes: #7550 * Add test cases for #7550 Signed-off-by: Stephen Finucane <stephen@that.guru>
This commit is contained in:
		
							parent
							
								
									26830c3d2d
								
							
						
					
					
						commit
						5185cc9348
					
				| 
						 | 
					@ -10,7 +10,7 @@ from django.utils.encoding import smart_str, uri_to_iri
 | 
				
			||||||
from django.utils.translation import gettext_lazy as _
 | 
					from django.utils.translation import gettext_lazy as _
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from rest_framework.fields import (
 | 
					from rest_framework.fields import (
 | 
				
			||||||
    Field, empty, get_attribute, is_simple_callable, iter_options
 | 
					    Field, SkipField, empty, get_attribute, is_simple_callable, iter_options
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from rest_framework.reverse import reverse
 | 
					from rest_framework.reverse import reverse
 | 
				
			||||||
from rest_framework.settings import api_settings
 | 
					from rest_framework.settings import api_settings
 | 
				
			||||||
| 
						 | 
					@ -535,7 +535,30 @@ class ManyRelatedField(Field):
 | 
				
			||||||
        if hasattr(instance, 'pk') and instance.pk is None:
 | 
					        if hasattr(instance, 'pk') and instance.pk is None:
 | 
				
			||||||
            return []
 | 
					            return []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        relationship = get_attribute(instance, self.source_attrs)
 | 
					        try:
 | 
				
			||||||
 | 
					            relationship = get_attribute(instance, self.source_attrs)
 | 
				
			||||||
 | 
					        except (KeyError, AttributeError) as exc:
 | 
				
			||||||
 | 
					            if self.default is not empty:
 | 
				
			||||||
 | 
					                return self.get_default()
 | 
				
			||||||
 | 
					            if self.allow_null:
 | 
				
			||||||
 | 
					                return None
 | 
				
			||||||
 | 
					            if not self.required:
 | 
				
			||||||
 | 
					                raise SkipField()
 | 
				
			||||||
 | 
					            msg = (
 | 
				
			||||||
 | 
					                'Got {exc_type} when attempting to get a value for field '
 | 
				
			||||||
 | 
					                '`{field}` on serializer `{serializer}`.\nThe serializer '
 | 
				
			||||||
 | 
					                'field might be named incorrectly and not match '
 | 
				
			||||||
 | 
					                'any attribute or key on the `{instance}` instance.\n'
 | 
				
			||||||
 | 
					                'Original exception text was: {exc}.'.format(
 | 
				
			||||||
 | 
					                    exc_type=type(exc).__name__,
 | 
				
			||||||
 | 
					                    field=self.field_name,
 | 
				
			||||||
 | 
					                    serializer=self.parent.__class__.__name__,
 | 
				
			||||||
 | 
					                    instance=instance.__class__.__name__,
 | 
				
			||||||
 | 
					                    exc=exc
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            raise type(exc)(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return relationship.all() if hasattr(relationship, 'all') else relationship
 | 
					        return relationship.all() if hasattr(relationship, 'all') else relationship
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to_representation(self, iterable):
 | 
					    def to_representation(self, iterable):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1025,6 +1025,73 @@ class Issue2704TestCase(TestCase):
 | 
				
			||||||
        assert serializer.data == expected
 | 
					        assert serializer.data == expected
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Issue7550FooModel(models.Model):
 | 
				
			||||||
 | 
					    text = models.CharField(max_length=100)
 | 
				
			||||||
 | 
					    bar = models.ForeignKey(
 | 
				
			||||||
 | 
					        'Issue7550BarModel', null=True, blank=True, on_delete=models.SET_NULL,
 | 
				
			||||||
 | 
					        related_name='foos', related_query_name='foo')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Issue7550BarModel(models.Model):
 | 
				
			||||||
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Issue7550TestCase(TestCase):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_dotted_source(self):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class _FooSerializer(serializers.ModelSerializer):
 | 
				
			||||||
 | 
					            class Meta:
 | 
				
			||||||
 | 
					                model = Issue7550FooModel
 | 
				
			||||||
 | 
					                fields = ('id', 'text')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class FooSerializer(serializers.ModelSerializer):
 | 
				
			||||||
 | 
					            other_foos = _FooSerializer(source='bar.foos', many=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            class Meta:
 | 
				
			||||||
 | 
					                model = Issue7550BarModel
 | 
				
			||||||
 | 
					                fields = ('id', 'other_foos')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        bar = Issue7550BarModel.objects.create()
 | 
				
			||||||
 | 
					        foo_a = Issue7550FooModel.objects.create(bar=bar, text='abc')
 | 
				
			||||||
 | 
					        foo_b = Issue7550FooModel.objects.create(bar=bar, text='123')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert FooSerializer(foo_a).data == {
 | 
				
			||||||
 | 
					            'id': foo_a.id,
 | 
				
			||||||
 | 
					            'other_foos': [
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    'id': foo_a.id,
 | 
				
			||||||
 | 
					                    'text': foo_a.text,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					                {
 | 
				
			||||||
 | 
					                    'id': foo_b.id,
 | 
				
			||||||
 | 
					                    'text': foo_b.text,
 | 
				
			||||||
 | 
					                },
 | 
				
			||||||
 | 
					            ],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_dotted_source_with_default(self):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class _FooSerializer(serializers.ModelSerializer):
 | 
				
			||||||
 | 
					            class Meta:
 | 
				
			||||||
 | 
					                model = Issue7550FooModel
 | 
				
			||||||
 | 
					                fields = ('id', 'text')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class FooSerializer(serializers.ModelSerializer):
 | 
				
			||||||
 | 
					            other_foos = _FooSerializer(source='bar.foos', default=[], many=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            class Meta:
 | 
				
			||||||
 | 
					                model = Issue7550FooModel
 | 
				
			||||||
 | 
					                fields = ('id', 'other_foos')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        foo = Issue7550FooModel.objects.create(bar=None, text='abc')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert FooSerializer(foo).data == {
 | 
				
			||||||
 | 
					            'id': foo.id,
 | 
				
			||||||
 | 
					            'other_foos': [],
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DecimalFieldModel(models.Model):
 | 
					class DecimalFieldModel(models.Model):
 | 
				
			||||||
    decimal_field = models.DecimalField(
 | 
					    decimal_field = models.DecimalField(
 | 
				
			||||||
        max_digits=3,
 | 
					        max_digits=3,
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user