From 9fdb2280d11db126771686d626aa8a0247b8a46c Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Sep 2014 14:23:00 +0100 Subject: [PATCH] First pass on ManyRelation --- rest_framework/relations.py | 42 +++++++++++++++++++- rest_framework/utils/representation.py | 2 + tests/test_model_serializer.py | 54 ++++++++++++++++++++++---- 3 files changed, 89 insertions(+), 9 deletions(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 9f44ab633..474d3e757 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -10,7 +10,6 @@ from django.utils.translation import ugettext_lazy as _ class RelatedField(Field): def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', None) - self.many = kwargs.pop('many', False) assert self.queryset is not None or kwargs.get('read_only', None), ( 'Relational field must provide a `queryset` argument, ' 'or set read_only=`True`.' @@ -21,6 +20,13 @@ class RelatedField(Field): ) super(RelatedField, self).__init__(**kwargs) + def __new__(cls, *args, **kwargs): + # We override this method in order to automagically create + # `ManyRelation` classes instead when `many=True` is set. + if kwargs.pop('many', False): + return ManyRelation(child_relation=cls(*args, **kwargs)) + return super(RelatedField, cls).__new__(cls, *args, **kwargs) + def get_queryset(self): queryset = self.queryset if isinstance(queryset, QuerySet): @@ -216,3 +222,37 @@ class SlugRelatedField(RelatedField): def to_representation(self, obj): return getattr(obj, self.slug_field) + + +class ManyRelation(Field): + """ + Relationships with `many=True` transparently get coerced into instead being + a ManyRelation with a child relationship. + + The `ManyRelation` class is responsible for handling iterating through + the values and passing each one to the child relationship. + + You shouldn't need to be using this class directly yourself. + """ + + def __init__(self, child_relation=None, *args, **kwargs): + self.child_relation = child_relation + assert child_relation is not None, '`child_relation` is a required argument.' + super(ManyRelation, self).__init__(*args, **kwargs) + + def bind(self, field_name, parent, root): + # ManyRelation needs to provide the current context to the child relation. + super(ManyRelation, self).bind(field_name, parent, root) + self.child_relation.bind(field_name, parent, root) + + def to_internal_value(self, data): + return [ + self.child_relation.to_internal_value(item) + for item in data + ] + + def to_representation(self, obj): + return [ + self.child_relation.to_representation(value) + for value in obj.all() + ] diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index 71db18863..e64fdd223 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -73,6 +73,8 @@ def serializer_repr(serializer, indent, force_many=None): ret += serializer_repr(field, indent + 1) elif hasattr(field, 'child'): ret += list_repr(field, indent + 1) + elif hasattr(field, 'child_relation'): + ret += field_repr(field.child_relation, force_many=field.child_relation) else: ret += field_repr(field) return ret diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 3ee91126f..b3dae7134 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -199,7 +199,7 @@ class RelationalModel(models.Model): class TestRelationalFieldMappings(TestCase): - def test_flat_relational_fields(self): + def test_pk_relations(self): class TestSerializer(serializers.ModelSerializer): class Meta: model = RelationalModel @@ -214,7 +214,7 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) - def test_nested_relational_fields(self): + def test_nested_relations(self): class TestSerializer(serializers.ModelSerializer): class Meta: model = RelationalModel @@ -238,7 +238,7 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) - def test_flat_hyperlinked_fields(self): + def test_hyperlinked_relations(self): class TestSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = RelationalModel @@ -253,7 +253,7 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) - def test_nested_hyperlinked_fields(self): + def test_nested_hyperlinked_relations(self): class TestSerializer(serializers.HyperlinkedModelSerializer): class Meta: model = RelationalModel @@ -277,7 +277,7 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) - def test_flat_reverse_foreign_key(self): + def test_pk_reverse_foreign_key(self): class TestSerializer(serializers.ModelSerializer): class Meta: model = ForeignKeyTargetModel @@ -291,7 +291,7 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) - def test_flat_reverse_one_to_one(self): + def test_pk_reverse_one_to_one(self): class TestSerializer(serializers.ModelSerializer): class Meta: model = OneToOneTargetModel @@ -305,7 +305,7 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) - def test_flat_reverse_many_to_many(self): + def test_pk_reverse_many_to_many(self): class TestSerializer(serializers.ModelSerializer): class Meta: model = ManyToManyTargetModel @@ -319,7 +319,7 @@ class TestRelationalFieldMappings(TestCase): """) self.assertEqual(repr(TestSerializer()), expected) - def test_flat_reverse_through(self): + def test_pk_reverse_through(self): class TestSerializer(serializers.ModelSerializer): class Meta: model = ThroughTargetModel @@ -332,3 +332,41 @@ class TestRelationalFieldMappings(TestCase): reverse_through = PrimaryKeyRelatedField(many=True, read_only=True) """) self.assertEqual(repr(TestSerializer()), expected) + + +class TestIntegration(TestCase): + def setUp(self): + self.foreign_key_target = ForeignKeyTargetModel.objects.create( + name='foreign_key' + ) + self.one_to_one_target = OneToOneTargetModel.objects.create( + name='one_to_one' + ) + self.many_to_many_targets = [ + ManyToManyTargetModel.objects.create( + name='many_to_many (%d)' % idx + ) for idx in range(3) + ] + self.instance = RelationalModel.objects.create( + foreign_key=self.foreign_key_target, + one_to_one=self.one_to_one_target, + ) + self.instance.many_to_many = self.many_to_many_targets + self.instance.save() + + class TestSerializer(serializers.ModelSerializer): + class Meta: + model = RelationalModel + + self.serializer_cls = TestSerializer + + def test_pk_relationship_representations(self): + serializer = self.serializer_cls(self.instance) + expected = { + 'id': self.instance.pk, + 'foreign_key': self.foreign_key_target.pk, + 'one_to_one': self.one_to_one_target.pk, + 'many_to_many': [item.pk for item in self.many_to_many_targets], + 'through': [] + } + self.assertEqual(serializer.data, expected)