From 09aff6daeba9dcbc9ee3216453bb1cfd156405f2 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Tue, 27 Jan 2015 22:15:08 -0500 Subject: [PATCH] alternative approach that allows for validation --- rest_framework/fields.py | 46 ++++++++++++++++++++++++++++------------ tests/test_fields.py | 5 +++-- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index be0cdec97..413aa5ded 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1156,6 +1156,9 @@ class ListField(Field): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert not inspect.isclass(self.child), '`child` has not been instantiated.' super(ListField, self).__init__(*args, **kwargs) + + def bind(self, field_name, parent): + super(ListField, self).bind(field_name, parent) self.child.bind(field_name='', parent=self) def get_value(self, dictionary): @@ -1290,24 +1293,41 @@ class RecursiveField(Field): next = RecursiveField(allow_null=True) """ - def __init__(self, *args, **kwargs): - kwargz = {'required': False} - kwargz.update(kwargs) - super(RecursiveField, self).__init__(*args, **kwargz) + def __init__(self, **kwargs): + field_kwargs = dict( + (key, value) + for key in kwargs + if key in inspect.getargspec(Field.__init__) + ) + super(RecursiveField, self).__init__(**field_kwargs) - def _get_parent(self): - if hasattr(self.parent, 'child') and self.parent.child is self: - # Recursive field nested inside of some kind of composite list field - return self.parent.parent + def bind(self, field_name, parent): + super(RecursiveField, self).bind(field_name, parent) + + real_dict = object.__getattribute__(self, '__dict__') + + if hasattr(parent, 'child') and parent.child is self: + proxy_class = parent.parent.__class__ else: - return self.parent + proxy_class = parent.__class__ - def to_representation(self, value): - return self._get_parent().to_representation(value) + proxy = proxy_class(**self._kwargs) + proxy.bind(field_name, parent) + real_dict['proxy'] = proxy - def to_internal_value(self, data): - return self._get_parent().to_internal_value(data) + def __getattribute__(self, name): + real_dict = object.__getattribute__(self, '__dict__') + if 'proxy' in real_dict and name != 'fields' and not (name.startswith('__') and name.endswith('__')): + return object.__getattribute__(real_dict['proxy'], name) + else: + return object.__getattribute__(self, name) + def __setattr__(self, name, value): + real_dict = object.__getattribute__(self, '__dict__') + if 'proxy' in real_dict: + setattr(real_dict['proxy'], name, value) + else: + real_dict[name] = value class SerializerMethodField(Field): """ diff --git a/tests/test_fields.py b/tests/test_fields.py index a832b75f0..3064a6056 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -300,7 +300,7 @@ class TestRecursiveField: def setup(self): class LinkSerializer(serializers.Serializer): name = serializers.CharField() - next = serializers.RecursiveField(allow_null=True) + next = serializers.RecursiveField(required=False, allow_null=True) self.link_serializer = LinkSerializer class NodeSerializer(serializers.Serializer): @@ -322,11 +322,13 @@ class TestRecursiveField: # test serialization serializer = self.link_serializer(value) + assert serializer.data == value, \ 'serialized data does not match input' # test deserialization serializer = self.link_serializer(data=value) + assert serializer.is_valid(), \ 'cannot validate on deserialization: %s' % dict(serializer.errors) assert serializer.validated_data == value, \ @@ -356,7 +358,6 @@ class TestRecursiveField: assert serializer.validated_data == value, \ 'deserialized data does not match input' - # Tests for field input and output values. # ----------------------------------------