First pass on ManyRelation

This commit is contained in:
Tom Christie 2014-09-18 14:23:00 +01:00
parent 3bc628edc0
commit 9fdb2280d1
3 changed files with 89 additions and 9 deletions

View File

@ -10,7 +10,6 @@ from django.utils.translation import ugettext_lazy as _
class RelatedField(Field):
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', None)
self.many = kwargs.pop('many', False)
assert self.queryset is not None or kwargs.get('read_only', None), (
'Relational field must provide a `queryset` argument, '
'or set read_only=`True`.'
@ -21,6 +20,13 @@ class RelatedField(Field):
)
super(RelatedField, self).__init__(**kwargs)
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
# `ManyRelation` classes instead when `many=True` is set.
if kwargs.pop('many', False):
return ManyRelation(child_relation=cls(*args, **kwargs))
return super(RelatedField, cls).__new__(cls, *args, **kwargs)
def get_queryset(self):
queryset = self.queryset
if isinstance(queryset, QuerySet):
@ -216,3 +222,37 @@ class SlugRelatedField(RelatedField):
def to_representation(self, obj):
return getattr(obj, self.slug_field)
class ManyRelation(Field):
"""
Relationships with `many=True` transparently get coerced into instead being
a ManyRelation with a child relationship.
The `ManyRelation` class is responsible for handling iterating through
the values and passing each one to the child relationship.
You shouldn't need to be using this class directly yourself.
"""
def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation
assert child_relation is not None, '`child_relation` is a required argument.'
super(ManyRelation, self).__init__(*args, **kwargs)
def bind(self, field_name, parent, root):
# ManyRelation needs to provide the current context to the child relation.
super(ManyRelation, self).bind(field_name, parent, root)
self.child_relation.bind(field_name, parent, root)
def to_internal_value(self, data):
return [
self.child_relation.to_internal_value(item)
for item in data
]
def to_representation(self, obj):
return [
self.child_relation.to_representation(value)
for value in obj.all()
]

View File

@ -73,6 +73,8 @@ def serializer_repr(serializer, indent, force_many=None):
ret += serializer_repr(field, indent + 1)
elif hasattr(field, 'child'):
ret += list_repr(field, indent + 1)
elif hasattr(field, 'child_relation'):
ret += field_repr(field.child_relation, force_many=field.child_relation)
else:
ret += field_repr(field)
return ret

View File

@ -199,7 +199,7 @@ class RelationalModel(models.Model):
class TestRelationalFieldMappings(TestCase):
def test_flat_relational_fields(self):
def test_pk_relations(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
@ -214,7 +214,7 @@ class TestRelationalFieldMappings(TestCase):
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_nested_relational_fields(self):
def test_nested_relations(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
@ -238,7 +238,7 @@ class TestRelationalFieldMappings(TestCase):
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_flat_hyperlinked_fields(self):
def test_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
@ -253,7 +253,7 @@ class TestRelationalFieldMappings(TestCase):
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_nested_hyperlinked_fields(self):
def test_nested_hyperlinked_relations(self):
class TestSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = RelationalModel
@ -277,7 +277,7 @@ class TestRelationalFieldMappings(TestCase):
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_flat_reverse_foreign_key(self):
def test_pk_reverse_foreign_key(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeyTargetModel
@ -291,7 +291,7 @@ class TestRelationalFieldMappings(TestCase):
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_flat_reverse_one_to_one(self):
def test_pk_reverse_one_to_one(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = OneToOneTargetModel
@ -305,7 +305,7 @@ class TestRelationalFieldMappings(TestCase):
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_flat_reverse_many_to_many(self):
def test_pk_reverse_many_to_many(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ManyToManyTargetModel
@ -319,7 +319,7 @@ class TestRelationalFieldMappings(TestCase):
""")
self.assertEqual(repr(TestSerializer()), expected)
def test_flat_reverse_through(self):
def test_pk_reverse_through(self):
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = ThroughTargetModel
@ -332,3 +332,41 @@ class TestRelationalFieldMappings(TestCase):
reverse_through = PrimaryKeyRelatedField(many=True, read_only=True)
""")
self.assertEqual(repr(TestSerializer()), expected)
class TestIntegration(TestCase):
def setUp(self):
self.foreign_key_target = ForeignKeyTargetModel.objects.create(
name='foreign_key'
)
self.one_to_one_target = OneToOneTargetModel.objects.create(
name='one_to_one'
)
self.many_to_many_targets = [
ManyToManyTargetModel.objects.create(
name='many_to_many (%d)' % idx
) for idx in range(3)
]
self.instance = RelationalModel.objects.create(
foreign_key=self.foreign_key_target,
one_to_one=self.one_to_one_target,
)
self.instance.many_to_many = self.many_to_many_targets
self.instance.save()
class TestSerializer(serializers.ModelSerializer):
class Meta:
model = RelationalModel
self.serializer_cls = TestSerializer
def test_pk_relationship_representations(self):
serializer = self.serializer_cls(self.instance)
expected = {
'id': self.instance.pk,
'foreign_key': self.foreign_key_target.pk,
'one_to_one': self.one_to_one_target.pk,
'many_to_many': [item.pk for item in self.many_to_many_targets],
'through': []
}
self.assertEqual(serializer.data, expected)