diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3697aa998..aad24f4ed 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1293,6 +1293,24 @@ class RecursiveField(Field): next = RecursiveField(allow_null=True) """ + # Implementation notes + # + # Only use __getattribute__ to forward the bound methods. Stop short of + # forwarding all attributes for the following reasons: + # - if you forward the __class__ attribute then deepcopy will give you back + # the wrong class + # - if you forward the fields attribute then __repr__ will enter into an + # infinite recursion + # - who knows what other infinite recursions are possible + # + # We only forward bound methods, but there are some attributes that must be + # accessible on both the RecursiveField and the proxied serializer, namely: + # field_name, read_only, default, source_attrs, write_attrs, source. As far + # as I can tell, the cleanest way to make these fields availabe without + # piecemeal forwarding them through __getattribute__ is to call bind and + # __init__ on both the RecursiveField and the proxied field using the exact + # same arguments. + def __init__(self, **kwargs): field_kwargs = dict( (key, kwargs[key]) @@ -1302,6 +1320,8 @@ class RecursiveField(Field): super(RecursiveField, self).__init__(**field_kwargs) def bind(self, field_name, parent): + super(RecursiveField, self).bind(field_name, parent) + if hasattr(parent, 'child') and parent.child is self: proxy_class = parent.parent.__class__ else: @@ -1314,19 +1334,15 @@ class RecursiveField(Field): def __getattribute__(self, name): d = object.__getattribute__(self, '__dict__') - # do not alias the fields parameter to prevent __repr__ from - # infinite recursion - if 'proxy' in d and name != 'fields' and name != 'proxy' and \ - not (name.startswith('__') and name.endswith('__')): - return getattr(d['proxy'], name) - else: - return object.__getattribute__(self, name) - - def __setattr__(self, name, value): - if 'proxy' in self.__dict__ and name is not 'proxy': - setattr(self.__dict__['proxy'], name, value) - else: - self.__dict__[name] = value + if 'proxy' in d: + try: + attr = getattr(d['proxy'], name) + + if hasattr(attr, '__self__'): + return attr + except AttributeError: + pass + return object.__getattribute__(self, name) class SerializerMethodField(Field): diff --git a/tests/test_fields.py b/tests/test_fields.py index e5ac02f51..6744cf645 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -293,72 +293,6 @@ class TestCreateOnlyDefault: } -# Tests for RecursiveField. -# ------------------------- - -class TestRecursiveField: - def setup(self): - class LinkSerializer(serializers.Serializer): - name = serializers.CharField() - next = serializers.RecursiveField(required=False, allow_null=True) - self.link_serializer = LinkSerializer - - class NodeSerializer(serializers.Serializer): - name = serializers.CharField() - children = serializers.ListField(child=serializers.RecursiveField()) - self.node_serializer = NodeSerializer - - def test_link_serializer(self): - value = { - 'name': 'first', - 'next': { - 'name': 'second', - 'next': { - 'name': 'third', - 'next': None, - } - } - } - - # test serialization - serializer = self.link_serializer(value) - - assert serializer.data == value, \ - 'serialized data does not match input' - - # test deserialization - serializer = self.link_serializer(data=value) - - assert serializer.is_valid(), \ - 'cannot validate on deserialization: %s' % dict(serializer.errors) - assert serializer.validated_data == value, \ - 'deserialized data does not match input' - - def test_node_serializer(self): - value = { - 'name': 'root', - 'children': [{ - 'name': 'first child', - 'children': [], - }, { - 'name': 'second child', - 'children': [], - }] - } - - # serialization - serializer = self.node_serializer(value) - assert serializer.data == value, \ - 'serialized data does not match input' - - # deserialization - serializer = self.node_serializer(data=value) - assert serializer.is_valid(), \ - 'cannot validate on deserialization: %s' % dict(serializer.errors) - assert serializer.validated_data == value, \ - 'deserialized data does not match input' - - # Tests for field input and output values. # ---------------------------------------- diff --git a/tests/test_recursive.py b/tests/test_recursive.py new file mode 100644 index 000000000..04a38c1c7 --- /dev/null +++ b/tests/test_recursive.py @@ -0,0 +1,59 @@ +from rest_framework import serializers + + +class LinkSerializer(serializers.Serializer): + name = serializers.CharField() + next = serializers.RecursiveField(required=False, allow_null=True) + + +class NodeSerializer(serializers.Serializer): + name = serializers.CharField() + children = serializers.ListField(child=serializers.RecursiveField()) + + +class TestRecursiveField: + @staticmethod + def serialize(serializer_class, value): + serializer = serializer_class(value) + + assert serializer.data == value, \ + 'serialized data does not match input' + + @staticmethod + def deserialize(serializer_class, data): + serializer = serializer_class(data=data) + + assert serializer.is_valid(), \ + 'cannot validate on deserialization: %s' % dict(serializer.errors) + assert serializer.validated_data == data, \ + 'deserialized data does not match input' + + def test_link_serializer(self): + value = { + 'name': 'first', + 'next': { + 'name': 'second', + 'next': { + 'name': 'third', + 'next': None, + } + } + } + + self.serialize(LinkSerializer, value) + self.deserialize(LinkSerializer, value) + + def test_node_serializer(self): + value = { + 'name': 'root', + 'children': [{ + 'name': 'first child', + 'children': [], + }, { + 'name': 'second child', + 'children': [], + }] + } + + self.serialize(NodeSerializer, value) + self.deserialize(NodeSerializer, value)