From 76e039d70e8fc7f1d5c65180cb544abab81e600e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 10 Apr 2013 22:38:02 +0100 Subject: [PATCH] First pass on automatically including reverse relationship --- rest_framework/serializers.py | 43 +++++++++++++++++++++++++----- rest_framework/tests/serializer.py | 37 +++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e28bbe81d..eac909c7b 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -598,6 +598,24 @@ class ModelSerializer(Serializer): if field: ret[model_field.name] = field + # Reverse relationships are only included if they are explicitly + # present in `Meta.fields`. + if self.opts.fields: + reverse = opts.get_all_related_objects() + reverse += opts.get_all_related_many_to_many_objects() + for rel in reverse: + name = rel.get_accessor_name() + if name not in self.opts.fields: + continue + + if nested: + field = self.get_nested_field(None, rel) + else: + field = self.get_related_field(None, rel, to_many=True) + + if field: + ret[name] = field + for field_name in self.opts.read_only_fields: assert field_name in ret, \ "read_only_fields on '%s' included invalid item '%s'" % \ @@ -612,24 +630,36 @@ class ModelSerializer(Serializer): """ return self.get_field(model_field) - def get_nested_field(self, model_field): + def get_nested_field(self, model_field, rel=None): """ Creates a default instance of a nested relational field. """ + if rel: + model_class = rel.model + else: + model_class = model_field.rel.to + class NestedModelSerializer(ModelSerializer): class Meta: - model = model_field.rel.to + model = model_class return NestedModelSerializer() - def get_related_field(self, model_field, to_many=False): + def get_related_field(self, model_field, rel=None, to_many=False): """ Creates a default instance of a flat relational field. """ # TODO: filter queryset using: # .using(db).complex_filter(self.rel.limit_choices_to) + if rel: + model_class = rel.model + required = True + else: + model_class = model_field.rel.to + required = not(model_field.null or model_field.blank) + kwargs = { - 'required': not(model_field.null or model_field.blank), - 'queryset': model_field.rel.to._default_manager, + 'required': required, + 'queryset': model_class._default_manager, 'many': to_many } @@ -797,7 +827,8 @@ class HyperlinkedModelSerializer(ModelSerializer): return self._default_view_name % format_kwargs def get_pk_field(self, model_field): - return None + if self.opts.fields and model_field.name in self.opts.fields: + return self.get_field(model_field) def get_related_field(self, model_field, to_many): """ diff --git a/rest_framework/tests/serializer.py b/rest_framework/tests/serializer.py index 05217f35a..3a94fad5d 100644 --- a/rest_framework/tests/serializer.py +++ b/rest_framework/tests/serializer.py @@ -738,6 +738,43 @@ class ManyRelatedTests(TestCase): self.assertEqual(serializer.data, expected) + def test_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2] + } + self.assertEqual(serializer.data, expected) + + def test_depth_include_reverse_relations(self): + post = BlogPost.objects.create(title="Test blog post") + post.blogpostcomment_set.create(text="I hate this blog post") + post.blogpostcomment_set.create(text="I love this blog post") + + class BlogPostSerializer(serializers.ModelSerializer): + class Meta: + model = BlogPost + fields = ('id', 'title', 'blogpostcomment_set') + depth = 1 + + serializer = BlogPostSerializer(instance=post) + expected = { + 'id': 1, 'title': 'Test blog post', + 'blogpostcomment_set': [ + {'id': 1, 'text': 'I hate this blog post', 'blog_post': 1}, + {'id': 2, 'text': 'I love this blog post', 'blog_post': 1} + ] + } + self.assertEqual(serializer.data, expected) + def test_callable_source(self): post = BlogPost.objects.create(title="Test blog post") post.blogpostcomment_set.create(text="I love this blog post")