diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 35c00bf1d..072c4af6d 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -144,13 +144,17 @@ class RelatedField(WritableField): return None if self.many: - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] - else: - # Also support non-queryset iterables. - # This allows us to also support plain lists of related items. - return [self.to_native(item) for item in value] - return self.to_native(value) + return self.value_to_native(value) + else: + return self.to_native(value) + + def value_to_native(self, value): + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item) for item in value.all()] + else: + # Also support non-queryset iterables. + # This allows us to also support plain lists of related items. + return [self.to_native(item) for item in value] def field_from_native(self, data, files, field_name, into): if self.read_only: @@ -249,12 +253,7 @@ class PrimaryKeyRelatedField(RelatedField): queryset = get_component(queryset, component) # Forward relationship - if is_simple_callable(getattr(queryset, 'all', None)): - return [self.to_native(item.pk) for item in queryset.all()] - else: - # Also support non-queryset iterables. - # This allows us to also support plain lists of related items. - return [self.to_native(item.pk) for item in queryset] + return self.value_to_native(queryset) # To-one relationship try: @@ -270,6 +269,14 @@ class PrimaryKeyRelatedField(RelatedField): # Forward relationship return self.to_native(pk) + def value_to_native(self, value): + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item.pk) for item in value.all()] + else: + # Also support non-queryset iterables. + # This allows us to also support plain lists of related items. + return [self.to_native(item.pk) for item in value] + ### Slug relationships diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4210d058b..8b57f3531 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -351,6 +351,12 @@ class BaseSerializer(WritableField): field.label = pretty_name(key) return field + def value_from_native(self, value): + return value + + def value_to_native(self, value): + return value + def field_to_native(self, obj, field_name): """ Override default so that the serializer can be used as a nested field @@ -371,8 +377,9 @@ class BaseSerializer(WritableField): except ObjectDoesNotExist: return None - if is_simple_callable(getattr(value, 'all', None)): - return [self.to_native(item) for item in value.all()] + native_value = self.value_to_native(value) + if native_value is not None: + return native_value if value is None: return None @@ -407,7 +414,7 @@ class BaseSerializer(WritableField): # Set the serializer object if it exists obj = get_component(self.parent.object, self.source or field_name) if self.parent.object else None - obj = obj.all() if is_simple_callable(getattr(obj, 'all', None)) else obj + obj = self.value_from_native(obj) if self.source == '*': if value: @@ -905,6 +912,16 @@ class ModelSerializer(Serializer): return instance + def value_from_native(self, value): + if is_simple_callable(getattr(value, 'all', None)): + return value.all() + else: + return value + + def value_to_native(self, value): + if is_simple_callable(getattr(value, 'all', None)): + return [self.to_native(item) for item in value.all()] + def from_native(self, data, files): """ Override the default method to also include model field validation.