This commit is contained in:
eofs 2013-01-29 08:39:26 -08:00
commit 8ba586e7d6
5 changed files with 107 additions and 2 deletions

View File

@ -237,6 +237,16 @@ The `RelatedField` class may be subclassed to create a custom representation of
All the relational fields may be used for any relationship or reverse relationship on a model. All the relational fields may be used for any relationship or reverse relationship on a model.
## Reverse relational fields
By default reverse relational fields are not displayed when ModelSerializer is used. You can control this behavior by using `DEFAULT_INCLUDE_REVERSE_RELATIONS` setting.
Besides global setting you can also use model specific setting:
class BlogPostSerializer(serializer.ModelSerializer):
class Meta:
include_reverse_relations = True
## Specifying which fields should be included ## Specifying which fields should be included
If you only want a subset of the default fields to be used in a model serializer, you can do so using `fields` or `exclude` options, just as you would with a `ModelForm`. If you only want a subset of the default fields to be used in a model serializer, you can do so using `fields` or `exclude` options, just as you would with a `ModelForm`.

View File

@ -96,6 +96,12 @@ Default: `rest_framework.serializers.ModelSerializer`
Default: `rest_framework.pagination.PaginationSerializer` Default: `rest_framework.pagination.PaginationSerializer`
## DEFAULT_INCLUDE_REVERSE_RELATIONS
If set to `True`, ModelSerializer will display reverse relational fields from other models.
Default: `False`
## FILTER_BACKEND ## FILTER_BACKEND
The filter backend class that should be used for generic filtering. If set to `None` then generic filtering is disabled. The filter backend class that should be used for generic filtering. If set to `None` then generic filtering is disabled.

View File

@ -7,6 +7,7 @@ from django.db import models
from django.forms import widgets from django.forms import widgets
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework.compat import get_concrete_model from rest_framework.compat import get_concrete_model
from rest_framework.settings import api_settings
# Note: We do the following so that users of the framework can use this style: # Note: We do the following so that users of the framework can use this style:
# #
@ -371,6 +372,7 @@ class ModelSerializerOptions(SerializerOptions):
super(ModelSerializerOptions, self).__init__(meta) super(ModelSerializerOptions, self).__init__(meta)
self.model = getattr(meta, 'model', None) self.model = getattr(meta, 'model', None)
self.read_only_fields = getattr(meta, 'read_only_fields', ()) self.read_only_fields = getattr(meta, 'read_only_fields', ())
self.include_reverse_relations = getattr(meta, 'include_reverse_relations', api_settings.DEFAULT_INCLUDE_REVERSE_RELATIONS)
class ModelSerializer(Serializer): class ModelSerializer(Serializer):
@ -379,6 +381,24 @@ class ModelSerializer(Serializer):
""" """
_options_class = ModelSerializerOptions _options_class = ModelSerializerOptions
def get_reverse_fields(self, opts, fields):
# Construct a list of all relations
relations = []
relations += [obj for obj in opts.get_all_related_objects() if obj.field.serialize]
relations += [obj for obj in opts.get_all_related_many_to_many_objects() if obj.field.serialize]
# Construct a list of intermediate models
exclude = []
for field in fields:
if field.rel and hasattr(field.rel, 'through'):
exclude.append(field.rel.through)
# Intermediate models from reverse relations
for rel in relations:
if rel.field.rel and hasattr(rel.field.rel, 'through'):
exclude.append(rel.field.rel.through)
return [rel.field for rel in relations if rel.model not in exclude]
def get_default_fields(self): def get_default_fields(self):
""" """
Return all the fields that should be serialized for the model. Return all the fields that should be serialized for the model.
@ -393,6 +413,11 @@ class ModelSerializer(Serializer):
fields += [field for field in opts.fields if field.serialize] fields += [field for field in opts.fields if field.serialize]
fields += [field for field in opts.many_to_many if field.serialize] fields += [field for field in opts.many_to_many if field.serialize]
reverse_fields = []
if self.opts.include_reverse_relations:
reverse_fields = self.get_reverse_fields(opts, fields)
fields += reverse_fields
ret = SortedDict() ret = SortedDict()
nested = bool(self.opts.depth) nested = bool(self.opts.depth)
is_pk = True # First field in the list is the pk is_pk = True # First field in the list is the pk
@ -406,11 +431,20 @@ class ModelSerializer(Serializer):
elif model_field.rel: elif model_field.rel:
to_many = isinstance(model_field, to_many = isinstance(model_field,
models.fields.related.ManyToManyField) models.fields.related.ManyToManyField)
# Reverse relational fields must be dealt as Many fields
if model_field.model is not self.opts.model:
to_many = True
field = self.get_related_field(model_field, to_many=to_many) field = self.get_related_field(model_field, to_many=to_many)
else: else:
field = self.get_field(model_field) field = self.get_field(model_field)
if field: if field:
if model_field in reverse_fields:
# Get user set 'related_name' or automatically set field
# name e.g. 'comment_set'
name = model_field.related.get_accessor_name()
ret[name] = field
else:
ret[model_field.name] = field ret[model_field.name] = field
for field_name in self.opts.read_only_fields: for field_name in self.opts.read_only_fields:
@ -431,9 +465,17 @@ class ModelSerializer(Serializer):
""" """
Creates a default instance of a nested relational field. Creates a default instance of a nested relational field.
""" """
# Field has reverse relation if it's referring to different model
if self.opts.model is not model_field.rel.to:
# Get correct model from the relation
model_class = model_field.rel.to
else:
# Forward relation, no need for magic
model_class = model_field.model
class NestedModelSerializer(ModelSerializer): class NestedModelSerializer(ModelSerializer):
class Meta: class Meta:
model = model_field.rel.to model = model_class
return NestedModelSerializer() return NestedModelSerializer()
def get_related_field(self, model_field, to_many=False): def get_related_field(self, model_field, to_many=False):

View File

@ -55,6 +55,9 @@ DEFAULTS = {
'anon': None, 'anon': None,
}, },
# ModelSerializer
'DEFAULT_INCLUDE_REVERSE_RELATIONS': False,
# Pagination # Pagination
'PAGINATE_BY': None, 'PAGINATE_BY': None,
'PAGINATE_BY_PARAM': None, 'PAGINATE_BY_PARAM': None,

View File

@ -695,6 +695,50 @@ class ManyRelatedTests(TestCase):
self.assertEqual(serializer.data, expected) self.assertEqual(serializer.data, expected)
def test_include_reverse_relations(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I hate this blog post")
post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostCommentSerializer(serializers.Serializer):
text = serializers.CharField()
class BlogPostSerializer(serializers.ModelSerializer):
class Meta:
model = BlogPost
include_reverse_relations = True
serializer = BlogPostSerializer(instance=post)
expected = {
'id': 1, 'title': u'Test blog post', 'writer': None,
'blogpostcomment_set': [1, 2]
}
self.assertEqual(serializer.data, expected)
def test_depth_include_reverse_relations(self):
post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I hate this blog post")
post.blogpostcomment_set.create(text="I love this blog post")
class BlogPostCommentSerializer(serializers.Serializer):
text = serializers.CharField()
class BlogPostSerializer(serializers.ModelSerializer):
class Meta:
model = BlogPost
include_reverse_relations = True
depth = 1
serializer = BlogPostSerializer(instance=post)
expected = {
'id': 1, 'title': u'Test blog post', 'writer': None,
'blogpostcomment_set': [
{'id': 1, 'text': u'I hate this blog post', 'blog_post': 1},
{'id': 2, 'text': u'I love this blog post', 'blog_post': 1}
]
}
self.assertEqual(serializer.data, expected)
def test_callable_source(self): def test_callable_source(self):
post = BlogPost.objects.create(title="Test blog post") post = BlogPost.objects.create(title="Test blog post")
post.blogpostcomment_set.create(text="I love this blog post") post.blogpostcomment_set.create(text="I love this blog post")