Merge branch 'include_reverse_relations' of https://github.com/tomchristie/django-rest-framework into include_reverse_relations

This commit is contained in:
Tom Christie 2013-04-17 09:26:34 +01:00
commit bcf4cb2b4e
2 changed files with 74 additions and 6 deletions

View File

@ -598,6 +598,24 @@ class ModelSerializer(Serializer):
if field: if field:
ret[model_field.name] = field ret[model_field.name] = field
# Reverse relationships are only included if they are explicitly
# present in `Meta.fields`.
if self.opts.fields:
reverse = opts.get_all_related_objects()
reverse += opts.get_all_related_many_to_many_objects()
for rel in reverse:
name = rel.get_accessor_name()
if name not in self.opts.fields:
continue
if nested:
field = self.get_nested_field(None, rel)
else:
field = self.get_related_field(None, rel, to_many=True)
if field:
ret[name] = field
for field_name in self.opts.read_only_fields: for field_name in self.opts.read_only_fields:
assert field_name in ret, \ assert field_name in ret, \
"read_only_fields on '%s' included invalid item '%s'" % \ "read_only_fields on '%s' included invalid item '%s'" % \
@ -612,24 +630,36 @@ class ModelSerializer(Serializer):
""" """
return self.get_field(model_field) return self.get_field(model_field)
def get_nested_field(self, model_field): def get_nested_field(self, model_field, rel=None):
""" """
Creates a default instance of a nested relational field. Creates a default instance of a nested relational field.
""" """
if rel:
model_class = rel.model
else:
model_class = model_field.rel.to
class NestedModelSerializer(ModelSerializer): class NestedModelSerializer(ModelSerializer):
class Meta: class Meta:
model = model_field.rel.to model = model_class
return NestedModelSerializer() return NestedModelSerializer()
def get_related_field(self, model_field, to_many=False): def get_related_field(self, model_field, rel=None, to_many=False):
""" """
Creates a default instance of a flat relational field. Creates a default instance of a flat relational field.
""" """
# TODO: filter queryset using: # TODO: filter queryset using:
# .using(db).complex_filter(self.rel.limit_choices_to) # .using(db).complex_filter(self.rel.limit_choices_to)
if rel:
model_class = rel.model
required = True
else:
model_class = model_field.rel.to
required = not(model_field.null or model_field.blank)
kwargs = { kwargs = {
'required': not(model_field.null or model_field.blank), 'required': required,
'queryset': model_field.rel.to._default_manager, 'queryset': model_class._default_manager,
'many': to_many 'many': to_many
} }
@ -797,7 +827,8 @@ class HyperlinkedModelSerializer(ModelSerializer):
return self._default_view_name % format_kwargs return self._default_view_name % format_kwargs
def get_pk_field(self, model_field): def get_pk_field(self, model_field):
return None if self.opts.fields and model_field.name in self.opts.fields:
return self.get_field(model_field)
def get_related_field(self, model_field, to_many): def get_related_field(self, model_field, to_many):
""" """

View File

@ -738,6 +738,43 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_include_reverse_relations(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I hate this blog post")
post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostSerializer(serializers.ModelSerializer):
class Meta:
model = BlogPost
fields = ('id', 'title', 'blogpostcomment_set')
serializer = BlogPostSerializer(instance=post)
expected = {
'id': 1, 'title': 'Test blog post', 'blogpostcomment_set': [1, 2]
}
self.assertEqual(serializer.data, expected)
def test_depth_include_reverse_relations(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I hate this blog post")
post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostSerializer(serializers.ModelSerializer):
class Meta:
model = BlogPost
fields = ('id', 'title', 'blogpostcomment_set')
depth = 1
serializer = BlogPostSerializer(instance=post)
expected = {
'id': 1, 'title': 'Test blog post',
'blogpostcomment_set': [
{'id': 1, 'text': 'I hate this blog post', 'blog_post': 1},
{'id': 2, 'text': 'I love this blog post', 'blog_post': 1}
]
}
self.assertEqual(serializer.data, expected)
def test_callable_source(self): def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post") post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post") post.blogpostcomment_set.create(text="I love this blog post")