diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index da0af467f..624b69d8b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -92,7 +92,6 @@ class SerializerOptions(object): self.fields = getattr(meta, 'fields', ()) self.exclude = getattr(meta, 'exclude', ()) - class BaseSerializer(Field): class Meta(object): pass @@ -361,6 +360,7 @@ class ModelSerializerOptions(SerializerOptions): super(ModelSerializerOptions, self).__init__(meta) self.model = getattr(meta, 'model', None) self.read_only_fields = getattr(meta, 'read_only_fields', ()) + self.include_reversed_relations = getattr(meta, 'include_reversed_relations', False) class ModelSerializer(Serializer): @@ -383,6 +383,12 @@ class ModelSerializer(Serializer): fields += [field for field in opts.fields if field.serialize] fields += [field for field in opts.many_to_many if field.serialize] + reversed_fields = () + if self.opts.include_reversed_relations: + reversed_fields = [obj.field for obj in opts.get_all_related_objects() if obj.field.serialize] + reversed_fields = [obj.field for obj in opts.get_all_related_many_to_many_objects() if obj.field.serialize] + fields += reversed_fields + ret = SortedDict() nested = bool(self.opts.depth) is_pk = True # First field in the list is the pk @@ -401,7 +407,10 @@ class ModelSerializer(Serializer): field = self.get_field(model_field) if field: - ret[model_field.name] = field + if model_field in reversed_fields: + ret[model_field.rel.related_name] = field + else: + ret[model_field.name] = field for field_name in self.opts.read_only_fields: assert field_name in ret, \ @@ -421,9 +430,12 @@ class ModelSerializer(Serializer): """ Creates a default instance of a nested relational field. """ + + # If field is reversed relation, get model from relation + obj_model = model_field.rel.to if self.opts.model is not model_field.rel.to else model_field.model class NestedModelSerializer(ModelSerializer): class Meta: - model = model_field.rel.to + model = obj_model return NestedModelSerializer() def get_related_field(self, model_field, to_many=False):