From 8fc9a7f267abfd466f2151c022807eede33e65b5 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Sun, 25 Jan 2015 22:12:03 -0500 Subject: [PATCH 01/22] recursive fields and a single test --- rest_framework/fields.py | 32 ++++++++++++++++++++++++++++++++ tests/test_fields.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index cc9410aa7..a74ce0d78 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1196,6 +1196,38 @@ class HiddenField(Field): def to_internal_value(self, data): return data +class RecursiveField(Field): + """ + A field that gets its representation from its parent. + + This method could be used to serialize a tree structure, a linked list, or + even a directed acyclic graph. As with all recursive things, it is + important to keep the base case in mind. In the case of the tree serializer + example below, the base case is a node with an empty list of children. In + the case of the list serializer below, the base case is when `next==None`. + Above all, beware of cyclical references. + + Examples: + + class TreeSerializer(self): + children = ListField(child=RecursiveField()) + + class ListSerializer(self): + next = RecursiveField(allow_null=True) + """ + + def _get_parent(self): + if hasattr(self.parent, 'child') and self.parent.child is self: + # Recursive field nested inside of some kind of composite list field + return self.parent.parent + else: + return self.parent + + def to_representation(self, value): + return self._get_parent().to_representation(value) + + def to_internal_value(self, data): + return self._get_parent().to_internal_value(data) class SerializerMethodField(Field): """ diff --git a/tests/test_fields.py b/tests/test_fields.py index 775d46184..72455b3a1 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -292,6 +292,35 @@ class TestCreateOnlyDefault: } +# Tests for RecursiveField. +# ------------------------- +class TestRecursiveField: + def setup(self): + class ListSerializer(serializers.Serializer): + name = serializers.CharField() + next = serializers.RecursiveField(allow_null=True) + self.list_serializer = ListSerializer + + class TreeSerializer(serializers.Serializer): + name = serializers.CharField() + children = serializers.ListField(child=serializers.RecursiveField()) + self.tree_serializer = TreeSerializer + + def test_serialize_list(self): + value = { + 'name':'first', + 'next': { + 'name':'second', + 'next':{ + 'name':'third', + 'next':None, + } + } + } + + serializer = self.list_serializer(value) + assert serializer.data == value + # Tests for field input and output values. # ---------------------------------------- @@ -1045,7 +1074,6 @@ class TestListField(FieldValues): ] field = serializers.ListField(child=serializers.IntegerField()) - # Tests for FieldField. # --------------------- From 6302fca16c4cab915505c42c0873d34cb5307d12 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Sun, 25 Jan 2015 22:18:20 -0500 Subject: [PATCH 02/22] returning a space --- tests/test_fields.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_fields.py b/tests/test_fields.py index 72455b3a1..a7bdad5c0 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1074,6 +1074,7 @@ class TestListField(FieldValues): ] field = serializers.ListField(child=serializers.IntegerField()) + # Tests for FieldField. # --------------------- From 7113a120bf12de80b7719f246a714a32309fd9ba Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Sun, 25 Jan 2015 22:21:55 -0500 Subject: [PATCH 03/22] spaces --- tests/test_fields.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_fields.py b/tests/test_fields.py index a7bdad5c0..6b652c92f 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -308,12 +308,12 @@ class TestRecursiveField: def test_serialize_list(self): value = { - 'name':'first', + 'name': 'first', 'next': { - 'name':'second', + 'name': 'second', 'next':{ - 'name':'third', - 'next':None, + 'name': 'third', + 'next': None, } } } From 2e8e9b8f4ceebd5c273d31ca56326304716c3d75 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Sun, 25 Jan 2015 22:33:58 -0500 Subject: [PATCH 04/22] more testing --- tests/test_fields.py | 49 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/tests/test_fields.py b/tests/test_fields.py index 6b652c92f..eb0dab875 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -294,19 +294,20 @@ class TestCreateOnlyDefault: # Tests for RecursiveField. # ------------------------- + class TestRecursiveField: def setup(self): - class ListSerializer(serializers.Serializer): + class LinkSerializer(serializers.Serializer): name = serializers.CharField() next = serializers.RecursiveField(allow_null=True) - self.list_serializer = ListSerializer + self.link_serializer = LinkSerializer - class TreeSerializer(serializers.Serializer): + class NodeSerializer(serializers.Serializer): name = serializers.CharField() children = serializers.ListField(child=serializers.RecursiveField()) - self.tree_serializer = TreeSerializer + self.node_serializer = NodeSerializer - def test_serialize_list(self): + def test_link_serializer(self): value = { 'name': 'first', 'next': { @@ -318,8 +319,42 @@ class TestRecursiveField: } } - serializer = self.list_serializer(value) - assert serializer.data == value + # test serialization + serializer = self.link_serializer(value) + assert serializer.data == value, \ + 'serialized data does not match input' + + # test deserialization + serializer = self.link_serializer(data=value) + assert serializer.is_valid(), \ + 'cannot validate on deserialization' + assert serializer.validated_data == value, \ + 'deserialized data does not match input' + + def test_node_serializer(self): + value = { + 'name': 'root', + 'children': [{ + 'name': 'first child', + 'children': [], + },{ + 'name': 'second child', + 'children': [], + }] + } + + # serialization + serializer = self.node_serializer(value) + assert serializer.data == value, \ + 'serialized data does not match input' + + # deserialization + serializer = self.link_serializer(data=value) + assert serializer.is_valid(), \ + 'cannot validate on deserialization' + assert serializer.validated_data == value, \ + 'deserialized data does not match input' + # Tests for field input and output values. # ---------------------------------------- From 09edaf50edbffaa9ed7013a6688184e7d2c83625 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Sun, 25 Jan 2015 23:05:21 -0500 Subject: [PATCH 05/22] fixed tests --- rest_framework/fields.py | 11 +++++++++-- tests/test_fields.py | 12 ++++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index a74ce0d78..4366bdc3e 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1196,6 +1196,7 @@ class HiddenField(Field): def to_internal_value(self, data): return data + class RecursiveField(Field): """ A field that gets its representation from its parent. @@ -1208,13 +1209,18 @@ class RecursiveField(Field): Above all, beware of cyclical references. Examples: - + class TreeSerializer(self): children = ListField(child=RecursiveField()) class ListSerializer(self): - next = RecursiveField(allow_null=True) + next = RecursiveField() """ + + def __init__(self, *args, **kwargs): + kwargz = {'required': False} + kwargz.update(kwargs) + super(RecursiveField, self).__init__(*args, **kwargz) def _get_parent(self): if hasattr(self.parent, 'child') and self.parent.child is self: @@ -1229,6 +1235,7 @@ class RecursiveField(Field): def to_internal_value(self, data): return self._get_parent().to_internal_value(data) + class SerializerMethodField(Field): """ A read-only field that get its representation from calling a method on the diff --git a/tests/test_fields.py b/tests/test_fields.py index eb0dab875..6e7bee4c1 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -312,7 +312,7 @@ class TestRecursiveField: 'name': 'first', 'next': { 'name': 'second', - 'next':{ + 'next': { 'name': 'third', 'next': None, } @@ -327,17 +327,17 @@ class TestRecursiveField: # test deserialization serializer = self.link_serializer(data=value) assert serializer.is_valid(), \ - 'cannot validate on deserialization' + 'cannot validate on deserialization: %s' % dict(serializer.errors) assert serializer.validated_data == value, \ 'deserialized data does not match input' def test_node_serializer(self): value = { - 'name': 'root', + 'name': 'root', 'children': [{ 'name': 'first child', 'children': [], - },{ + }, { 'name': 'second child', 'children': [], }] @@ -349,9 +349,9 @@ class TestRecursiveField: 'serialized data does not match input' # deserialization - serializer = self.link_serializer(data=value) + serializer = self.node_serializer(data=value) assert serializer.is_valid(), \ - 'cannot validate on deserialization' + 'cannot validate on deserialization: %s' % dict(serializer.errors) assert serializer.validated_data == value, \ 'deserialized data does not match input' From cedfe4f33081b1ec0e90ab823d1a36a5bd21b57b Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Sun, 25 Jan 2015 23:06:58 -0500 Subject: [PATCH 06/22] flake ugh --- rest_framework/fields.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4366bdc3e..a63dbe919 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1216,7 +1216,7 @@ class RecursiveField(Field): class ListSerializer(self): next = RecursiveField() """ - + def __init__(self, *args, **kwargs): kwargz = {'required': False} kwargz.update(kwargs) @@ -1228,7 +1228,7 @@ class RecursiveField(Field): return self.parent.parent else: return self.parent - + def to_representation(self, value): return self._get_parent().to_representation(value) From c6d44789e23d463974404064a61017ccd3d25d59 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Sun, 25 Jan 2015 23:14:12 -0500 Subject: [PATCH 07/22] a comment --- rest_framework/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index fc5db78cc..be0cdec97 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1287,7 +1287,7 @@ class RecursiveField(Field): children = ListField(child=RecursiveField()) class ListSerializer(self): - next = RecursiveField() + next = RecursiveField(allow_null=True) """ def __init__(self, *args, **kwargs): From 09aff6daeba9dcbc9ee3216453bb1cfd156405f2 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Tue, 27 Jan 2015 22:15:08 -0500 Subject: [PATCH 08/22] alternative approach that allows for validation --- rest_framework/fields.py | 46 ++++++++++++++++++++++++++++------------ tests/test_fields.py | 5 +++-- 2 files changed, 36 insertions(+), 15 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index be0cdec97..413aa5ded 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1156,6 +1156,9 @@ class ListField(Field): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert not inspect.isclass(self.child), '`child` has not been instantiated.' super(ListField, self).__init__(*args, **kwargs) + + def bind(self, field_name, parent): + super(ListField, self).bind(field_name, parent) self.child.bind(field_name='', parent=self) def get_value(self, dictionary): @@ -1290,24 +1293,41 @@ class RecursiveField(Field): next = RecursiveField(allow_null=True) """ - def __init__(self, *args, **kwargs): - kwargz = {'required': False} - kwargz.update(kwargs) - super(RecursiveField, self).__init__(*args, **kwargz) + def __init__(self, **kwargs): + field_kwargs = dict( + (key, value) + for key in kwargs + if key in inspect.getargspec(Field.__init__) + ) + super(RecursiveField, self).__init__(**field_kwargs) - def _get_parent(self): - if hasattr(self.parent, 'child') and self.parent.child is self: - # Recursive field nested inside of some kind of composite list field - return self.parent.parent + def bind(self, field_name, parent): + super(RecursiveField, self).bind(field_name, parent) + + real_dict = object.__getattribute__(self, '__dict__') + + if hasattr(parent, 'child') and parent.child is self: + proxy_class = parent.parent.__class__ else: - return self.parent + proxy_class = parent.__class__ - def to_representation(self, value): - return self._get_parent().to_representation(value) + proxy = proxy_class(**self._kwargs) + proxy.bind(field_name, parent) + real_dict['proxy'] = proxy - def to_internal_value(self, data): - return self._get_parent().to_internal_value(data) + def __getattribute__(self, name): + real_dict = object.__getattribute__(self, '__dict__') + if 'proxy' in real_dict and name != 'fields' and not (name.startswith('__') and name.endswith('__')): + return object.__getattribute__(real_dict['proxy'], name) + else: + return object.__getattribute__(self, name) + def __setattr__(self, name, value): + real_dict = object.__getattribute__(self, '__dict__') + if 'proxy' in real_dict: + setattr(real_dict['proxy'], name, value) + else: + real_dict[name] = value class SerializerMethodField(Field): """ diff --git a/tests/test_fields.py b/tests/test_fields.py index a832b75f0..3064a6056 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -300,7 +300,7 @@ class TestRecursiveField: def setup(self): class LinkSerializer(serializers.Serializer): name = serializers.CharField() - next = serializers.RecursiveField(allow_null=True) + next = serializers.RecursiveField(required=False, allow_null=True) self.link_serializer = LinkSerializer class NodeSerializer(serializers.Serializer): @@ -322,11 +322,13 @@ class TestRecursiveField: # test serialization serializer = self.link_serializer(value) + assert serializer.data == value, \ 'serialized data does not match input' # test deserialization serializer = self.link_serializer(data=value) + assert serializer.is_valid(), \ 'cannot validate on deserialization: %s' % dict(serializer.errors) assert serializer.validated_data == value, \ @@ -356,7 +358,6 @@ class TestRecursiveField: assert serializer.validated_data == value, \ 'deserialized data does not match input' - # Tests for field input and output values. # ---------------------------------------- From 25fee8e07ba66c398a9c50913d0074a2b7ac9693 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Tue, 27 Jan 2015 22:19:19 -0500 Subject: [PATCH 09/22] flake --- rest_framework/fields.py | 7 ++++--- tests/test_fields.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 413aa5ded..7ef4ce3d7 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1295,8 +1295,8 @@ class RecursiveField(Field): def __init__(self, **kwargs): field_kwargs = dict( - (key, value) - for key in kwargs + (key, kwargs[key]) + for key in kwargs if key in inspect.getargspec(Field.__init__) ) super(RecursiveField, self).__init__(**field_kwargs) @@ -1305,7 +1305,7 @@ class RecursiveField(Field): super(RecursiveField, self).bind(field_name, parent) real_dict = object.__getattribute__(self, '__dict__') - + if hasattr(parent, 'child') and parent.child is self: proxy_class = parent.parent.__class__ else: @@ -1329,6 +1329,7 @@ class RecursiveField(Field): else: real_dict[name] = value + class SerializerMethodField(Field): """ A read-only field that get its representation from calling a method on the diff --git a/tests/test_fields.py b/tests/test_fields.py index 3064a6056..e5ac02f51 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -358,6 +358,7 @@ class TestRecursiveField: assert serializer.validated_data == value, \ 'deserialized data does not match input' + # Tests for field input and output values. # ---------------------------------------- From 67ed7d7933920a0d8c62e6cac0ccbec277a36f43 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Tue, 27 Jan 2015 22:32:31 -0500 Subject: [PATCH 10/22] removed unnecessary bind on the recursive field --- rest_framework/fields.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7ef4ce3d7..7b3642fc1 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1302,8 +1302,6 @@ class RecursiveField(Field): super(RecursiveField, self).__init__(**field_kwargs) def bind(self, field_name, parent): - super(RecursiveField, self).bind(field_name, parent) - real_dict = object.__getattribute__(self, '__dict__') if hasattr(parent, 'child') and parent.child is self: From a596ace793812f755e1293d01c75e8bbe98192c5 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Tue, 27 Jan 2015 22:43:53 -0500 Subject: [PATCH 11/22] minor cleanup --- rest_framework/fields.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7b3642fc1..4c31cd838 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1302,8 +1302,6 @@ class RecursiveField(Field): super(RecursiveField, self).__init__(**field_kwargs) def bind(self, field_name, parent): - real_dict = object.__getattribute__(self, '__dict__') - if hasattr(parent, 'child') and parent.child is self: proxy_class = parent.parent.__class__ else: @@ -1311,21 +1309,24 @@ class RecursiveField(Field): proxy = proxy_class(**self._kwargs) proxy.bind(field_name, parent) - real_dict['proxy'] = proxy + self.proxy = proxy def __getattribute__(self, name): - real_dict = object.__getattribute__(self, '__dict__') - if 'proxy' in real_dict and name != 'fields' and not (name.startswith('__') and name.endswith('__')): - return object.__getattribute__(real_dict['proxy'], name) + d = object.__getattribute__(self, '__dict__') + + # do not alias the fields parameter to prevent __repr__ from + # infinite recursion + if 'proxy' in d and name != 'fields' and name != 'proxy' and \ + not (name.startswith('__') and name.endswith('__')): + return object.__getattribute__(d['proxy'], name) else: return object.__getattribute__(self, name) def __setattr__(self, name, value): - real_dict = object.__getattribute__(self, '__dict__') - if 'proxy' in real_dict: - setattr(real_dict['proxy'], name, value) + if 'proxy' in self.__dict__ and name is not 'proxy': + setattr(self.__dict__['proxy'], name, value) else: - real_dict[name] = value + self.__dict__[name] = value class SerializerMethodField(Field): From d4c192238980500de642971c9d1746a2fb5a6b53 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Tue, 27 Jan 2015 22:47:53 -0500 Subject: [PATCH 12/22] removing unnecessary call to object.__getattribute__ --- rest_framework/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 4c31cd838..3697aa998 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1318,7 +1318,7 @@ class RecursiveField(Field): # infinite recursion if 'proxy' in d and name != 'fields' and name != 'proxy' and \ not (name.startswith('__') and name.endswith('__')): - return object.__getattribute__(d['proxy'], name) + return getattr(d['proxy'], name) else: return object.__getattribute__(self, name) From e04324daab02a59410e2465c1076cdb739f591e7 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Wed, 28 Jan 2015 02:29:13 -0500 Subject: [PATCH 13/22] cleaner use of getattribute --- rest_framework/fields.py | 42 +++++++++++++++++-------- tests/test_fields.py | 66 ---------------------------------------- tests/test_recursive.py | 59 +++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 79 deletions(-) create mode 100644 tests/test_recursive.py diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 3697aa998..aad24f4ed 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1293,6 +1293,24 @@ class RecursiveField(Field): next = RecursiveField(allow_null=True) """ + # Implementation notes + # + # Only use __getattribute__ to forward the bound methods. Stop short of + # forwarding all attributes for the following reasons: + # - if you forward the __class__ attribute then deepcopy will give you back + # the wrong class + # - if you forward the fields attribute then __repr__ will enter into an + # infinite recursion + # - who knows what other infinite recursions are possible + # + # We only forward bound methods, but there are some attributes that must be + # accessible on both the RecursiveField and the proxied serializer, namely: + # field_name, read_only, default, source_attrs, write_attrs, source. As far + # as I can tell, the cleanest way to make these fields availabe without + # piecemeal forwarding them through __getattribute__ is to call bind and + # __init__ on both the RecursiveField and the proxied field using the exact + # same arguments. + def __init__(self, **kwargs): field_kwargs = dict( (key, kwargs[key]) @@ -1302,6 +1320,8 @@ class RecursiveField(Field): super(RecursiveField, self).__init__(**field_kwargs) def bind(self, field_name, parent): + super(RecursiveField, self).bind(field_name, parent) + if hasattr(parent, 'child') and parent.child is self: proxy_class = parent.parent.__class__ else: @@ -1314,19 +1334,15 @@ class RecursiveField(Field): def __getattribute__(self, name): d = object.__getattribute__(self, '__dict__') - # do not alias the fields parameter to prevent __repr__ from - # infinite recursion - if 'proxy' in d and name != 'fields' and name != 'proxy' and \ - not (name.startswith('__') and name.endswith('__')): - return getattr(d['proxy'], name) - else: - return object.__getattribute__(self, name) - - def __setattr__(self, name, value): - if 'proxy' in self.__dict__ and name is not 'proxy': - setattr(self.__dict__['proxy'], name, value) - else: - self.__dict__[name] = value + if 'proxy' in d: + try: + attr = getattr(d['proxy'], name) + + if hasattr(attr, '__self__'): + return attr + except AttributeError: + pass + return object.__getattribute__(self, name) class SerializerMethodField(Field): diff --git a/tests/test_fields.py b/tests/test_fields.py index e5ac02f51..6744cf645 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -293,72 +293,6 @@ class TestCreateOnlyDefault: } -# Tests for RecursiveField. -# ------------------------- - -class TestRecursiveField: - def setup(self): - class LinkSerializer(serializers.Serializer): - name = serializers.CharField() - next = serializers.RecursiveField(required=False, allow_null=True) - self.link_serializer = LinkSerializer - - class NodeSerializer(serializers.Serializer): - name = serializers.CharField() - children = serializers.ListField(child=serializers.RecursiveField()) - self.node_serializer = NodeSerializer - - def test_link_serializer(self): - value = { - 'name': 'first', - 'next': { - 'name': 'second', - 'next': { - 'name': 'third', - 'next': None, - } - } - } - - # test serialization - serializer = self.link_serializer(value) - - assert serializer.data == value, \ - 'serialized data does not match input' - - # test deserialization - serializer = self.link_serializer(data=value) - - assert serializer.is_valid(), \ - 'cannot validate on deserialization: %s' % dict(serializer.errors) - assert serializer.validated_data == value, \ - 'deserialized data does not match input' - - def test_node_serializer(self): - value = { - 'name': 'root', - 'children': [{ - 'name': 'first child', - 'children': [], - }, { - 'name': 'second child', - 'children': [], - }] - } - - # serialization - serializer = self.node_serializer(value) - assert serializer.data == value, \ - 'serialized data does not match input' - - # deserialization - serializer = self.node_serializer(data=value) - assert serializer.is_valid(), \ - 'cannot validate on deserialization: %s' % dict(serializer.errors) - assert serializer.validated_data == value, \ - 'deserialized data does not match input' - - # Tests for field input and output values. # ---------------------------------------- diff --git a/tests/test_recursive.py b/tests/test_recursive.py new file mode 100644 index 000000000..04a38c1c7 --- /dev/null +++ b/tests/test_recursive.py @@ -0,0 +1,59 @@ +from rest_framework import serializers + + +class LinkSerializer(serializers.Serializer): + name = serializers.CharField() + next = serializers.RecursiveField(required=False, allow_null=True) + + +class NodeSerializer(serializers.Serializer): + name = serializers.CharField() + children = serializers.ListField(child=serializers.RecursiveField()) + + +class TestRecursiveField: + @staticmethod + def serialize(serializer_class, value): + serializer = serializer_class(value) + + assert serializer.data == value, \ + 'serialized data does not match input' + + @staticmethod + def deserialize(serializer_class, data): + serializer = serializer_class(data=data) + + assert serializer.is_valid(), \ + 'cannot validate on deserialization: %s' % dict(serializer.errors) + assert serializer.validated_data == data, \ + 'deserialized data does not match input' + + def test_link_serializer(self): + value = { + 'name': 'first', + 'next': { + 'name': 'second', + 'next': { + 'name': 'third', + 'next': None, + } + } + } + + self.serialize(LinkSerializer, value) + self.deserialize(LinkSerializer, value) + + def test_node_serializer(self): + value = { + 'name': 'root', + 'children': [{ + 'name': 'first child', + 'children': [], + }, { + 'name': 'second child', + 'children': [], + }] + } + + self.serialize(NodeSerializer, value) + self.deserialize(NodeSerializer, value) From 01dfd8461f49f3b686eb65d0fb665c72108e9dde Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Mon, 26 Jan 2015 20:15:33 -0500 Subject: [PATCH 14/22] more tests --- rest_framework/fields.py | 17 ++++++-- tests/test_recursive.py | 93 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 105 insertions(+), 5 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index aad24f4ed..06a1ea6ff 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -21,6 +21,7 @@ import collections import copy import datetime import decimal +import importlib import inspect import re import uuid @@ -1311,7 +1312,9 @@ class RecursiveField(Field): # __init__ on both the RecursiveField and the proxied field using the exact # same arguments. - def __init__(self, **kwargs): + def __init__(self, to='self', to_module=None, **kwargs): + self.to = to + self.to_module = to_module field_kwargs = dict( (key, kwargs[key]) for key in kwargs @@ -1323,9 +1326,17 @@ class RecursiveField(Field): super(RecursiveField, self).bind(field_name, parent) if hasattr(parent, 'child') and parent.child is self: - proxy_class = parent.parent.__class__ + parent_class = parent.parent.__class__ else: - proxy_class = parent.__class__ + parent_class = parent.__class__ + + if self.to == 'self': + proxy_class = parent_class + else: + ref = importlib.import_module(self.to_module or parent_class.__module__) + for part in self.to.split('.'): + ref = getattr(ref, part) + proxy_class = ref proxy = proxy_class(**self._kwargs) proxy.bind(field_name, parent) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index 04a38c1c7..c675b3979 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -1,8 +1,7 @@ from rest_framework import serializers - class LinkSerializer(serializers.Serializer): - name = serializers.CharField() + name = serializers.CharField(max_length=25) next = serializers.RecursiveField(required=False, allow_null=True) @@ -11,6 +10,26 @@ class NodeSerializer(serializers.Serializer): children = serializers.ListField(child=serializers.RecursiveField()) +class PingSerializer(serializers.Serializer): + ping_id = serializers.IntegerField() + pong = serializers.RecursiveField('PongSerializer', required=False) + + +class PongSerializer(serializers.Serializer): + pong_id = serializers.IntegerField() + ping = PingSerializer() + + +class SillySerializer(serializers.Serializer): + name = serializers.RecursiveField( + 'CharField', 'rest_framework.fields', max_length=5) + blankable = serializers.RecursiveField( + 'CharField', 'rest_framework.fields', allow_blank=True) + nullable = serializers.RecursiveField( + 'CharField', 'rest_framework.fields', allow_null=True) + links = serializers.RecursiveField('LinkSerializer') + self = serializers.RecursiveField(required=False) + class TestRecursiveField: @staticmethod def serialize(serializer_class, value): @@ -57,3 +76,73 @@ class TestRecursiveField: self.serialize(NodeSerializer, value) self.deserialize(NodeSerializer, value) + + def test_ping_pong(self): + pong = { + 'pong_id': 4, + 'ping': { + 'ping_id': 3, + 'pong': { + 'pong_id': 2, + 'ping': { + 'ping_id': 1, + }, + }, + }, + } + self.serialize(PongSerializer, pong) + self.deserialize(PongSerializer, pong) + + def test_validation(self): + value = { + 'name': 'good', + 'blankable': '', + 'nullable': None, + 'links': { + 'name': 'something', + 'next': { + 'name': 'inner something', + } + } + } + self.serialize(SillySerializer, value) + self.deserialize(SillySerializer, value) + + max_length = { + 'name': 'too long', + 'blankable': 'not blank', + 'nullable': 'not null', + 'links': { + 'name': 'something', + } + } + serializer = SillySerializer(data=max_length) + assert not serializer.is_valid(), \ + 'validation should fail due to name too long' + + nulled_out = { + 'name': 'good', + 'blankable': None, + 'nullable': 'not null', + 'links': { + 'name': 'something', + } + } + serializer = SillySerializer(data=nulled_out) + assert not serializer.is_valid(), \ + 'validation should fail due to null field' + + way_too_long = { + 'name': 'good', + 'blankable': '', + 'nullable': None, + 'links': { + 'name': 'something', + 'next': { + 'name': 'inner something that is much too long', + } + } + } + serializer = SillySerializer(data=way_too_long) + assert not serializer.is_valid(), \ + 'validation should fail on inner link validation' From 672df63599bfdc486e2767edb83b173b83f5a690 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Wed, 28 Jan 2015 10:17:37 -0500 Subject: [PATCH 15/22] whitespace --- tests/test_recursive.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index c675b3979..c48f4925e 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -30,6 +30,7 @@ class SillySerializer(serializers.Serializer): links = serializers.RecursiveField('LinkSerializer') self = serializers.RecursiveField(required=False) + class TestRecursiveField: @staticmethod def serialize(serializer_class, value): From a63b9fde2ef690e79875b985ccc3315b422e8351 Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Wed, 28 Jan 2015 10:20:06 -0500 Subject: [PATCH 16/22] flake --- rest_framework/fields.py | 22 ++-------------------- tests/test_recursive.py | 3 ++- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 06a1ea6ff..cce597771 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1294,24 +1294,6 @@ class RecursiveField(Field): next = RecursiveField(allow_null=True) """ - # Implementation notes - # - # Only use __getattribute__ to forward the bound methods. Stop short of - # forwarding all attributes for the following reasons: - # - if you forward the __class__ attribute then deepcopy will give you back - # the wrong class - # - if you forward the fields attribute then __repr__ will enter into an - # infinite recursion - # - who knows what other infinite recursions are possible - # - # We only forward bound methods, but there are some attributes that must be - # accessible on both the RecursiveField and the proxied serializer, namely: - # field_name, read_only, default, source_attrs, write_attrs, source. As far - # as I can tell, the cleanest way to make these fields availabe without - # piecemeal forwarding them through __getattribute__ is to call bind and - # __init__ on both the RecursiveField and the proxied field using the exact - # same arguments. - def __init__(self, to='self', to_module=None, **kwargs): self.to = to self.to_module = to_module @@ -1336,7 +1318,7 @@ class RecursiveField(Field): ref = importlib.import_module(self.to_module or parent_class.__module__) for part in self.to.split('.'): ref = getattr(ref, part) - proxy_class = ref + proxy_class = ref proxy = proxy_class(**self._kwargs) proxy.bind(field_name, parent) @@ -1348,7 +1330,7 @@ class RecursiveField(Field): if 'proxy' in d: try: attr = getattr(d['proxy'], name) - + if hasattr(attr, '__self__'): return attr except AttributeError: diff --git a/tests/test_recursive.py b/tests/test_recursive.py index c48f4925e..b343bd727 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -1,5 +1,6 @@ from rest_framework import serializers + class LinkSerializer(serializers.Serializer): name = serializers.CharField(max_length=25) next = serializers.RecursiveField(required=False, allow_null=True) @@ -47,7 +48,7 @@ class TestRecursiveField: 'cannot validate on deserialization: %s' % dict(serializer.errors) assert serializer.validated_data == data, \ 'deserialized data does not match input' - + def test_link_serializer(self): value = { 'name': 'first', From 297012975178ceaedded35bf4c69f96d4196347e Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Thu, 29 Jan 2015 14:07:58 -0500 Subject: [PATCH 17/22] explicitly proxied fields --- rest_framework/fields.py | 70 ++++++++++++++++++++++++---------------- tests/test_recursive.py | 6 ++-- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index cce597771..c9110aa37 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -4,7 +4,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ValidationError as DjangoValidationError from django.core.validators import RegexValidator from django.forms import ImageField as DjangoImageField -from django.utils import six, timezone +from django.utils import six, timezone, importlib from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type, smart_text from django.utils.translation import ugettext_lazy as _ @@ -21,7 +21,6 @@ import collections import copy import datetime import decimal -import importlib import inspect import re import uuid @@ -1294,47 +1293,62 @@ class RecursiveField(Field): next = RecursiveField(allow_null=True) """ - def __init__(self, to='self', to_module=None, **kwargs): + # This list of attributes determined by the attributes that + # `rest_framework.serializers` calls to on a field object + PROXIED_ATTRS = ( + # bound fields + 'get_value', + 'get_initial', + 'run_validation', + 'get_attribute', + 'to_representation', + + # attributes + 'field_name', + 'source', + 'read_only', + 'default', + 'source_attrs', + 'write_only', + ) + + def __init__(self, to=None, **kwargs): self.to = to - self.to_module = to_module - field_kwargs = dict( - (key, kwargs[key]) - for key in kwargs - if key in inspect.getargspec(Field.__init__) - ) - super(RecursiveField, self).__init__(**field_kwargs) + self.kwargs = kwargs def bind(self, field_name, parent): - super(RecursiveField, self).bind(field_name, parent) - if hasattr(parent, 'child') and parent.child is self: parent_class = parent.parent.__class__ else: parent_class = parent.__class__ - if self.to == 'self': - proxy_class = parent_class + if self.to is None: + proxied_class = parent_class else: - ref = importlib.import_module(self.to_module or parent_class.__module__) - for part in self.to.split('.'): - ref = getattr(ref, part) - proxy_class = ref + try: + module_name, class_name = self.to.rsplit('.', 1) + except ValueError: + module_name, class_name = parent_class.__module__, self.to - proxy = proxy_class(**self._kwargs) - proxy.bind(field_name, parent) - self.proxy = proxy + try: + proxied_class = getattr( + importlib.import_module(module_name), class_name) + except Exception as e: + raise ImportError( + 'could not locate serializer %s' % self.to, e) + + proxied = proxied_class(**self.kwargs) + proxied.bind(field_name, parent) + self.proxied = proxied def __getattribute__(self, name): - d = object.__getattribute__(self, '__dict__') - - if 'proxy' in d: + if name in RecursiveField.PROXIED_ATTRS: try: - attr = getattr(d['proxy'], name) - - if hasattr(attr, '__self__'): - return attr + proxied = object.__getattribute__(self, 'proxied') + return getattr(proxied, name) except AttributeError: pass + return object.__getattribute__(self, name) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index b343bd727..f3e60d2b2 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -23,11 +23,11 @@ class PongSerializer(serializers.Serializer): class SillySerializer(serializers.Serializer): name = serializers.RecursiveField( - 'CharField', 'rest_framework.fields', max_length=5) + 'rest_framework.fields.CharField', max_length=5) blankable = serializers.RecursiveField( - 'CharField', 'rest_framework.fields', allow_blank=True) + 'rest_framework.fields.CharField', allow_blank=True) nullable = serializers.RecursiveField( - 'CharField', 'rest_framework.fields', allow_null=True) + 'rest_framework.fields.CharField', allow_null=True) links = serializers.RecursiveField('LinkSerializer') self = serializers.RecursiveField(required=False) From 0d85b5575ee79afd6db8a492d63ef2816c78056d Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Thu, 29 Jan 2015 14:11:03 -0500 Subject: [PATCH 18/22] flake whitespace --- rest_framework/fields.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c9110aa37..c32f4a754 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1296,14 +1296,14 @@ class RecursiveField(Field): # This list of attributes determined by the attributes that # `rest_framework.serializers` calls to on a field object PROXIED_ATTRS = ( - # bound fields + # bound fields 'get_value', 'get_initial', 'run_validation', 'get_attribute', 'to_representation', - - # attributes + + # attributes 'field_name', 'source', 'read_only', @@ -1348,7 +1348,7 @@ class RecursiveField(Field): return getattr(proxied, name) except AttributeError: pass - + return object.__getattribute__(self, name) From 3b07c1dfc6d40e3f7d7b71d388b39e40725c232e Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Thu, 29 Jan 2015 14:20:42 -0500 Subject: [PATCH 19/22] required not actually necessary if allow_null present --- tests/test_recursive.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index f3e60d2b2..a4524bbb7 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -3,7 +3,7 @@ from rest_framework import serializers class LinkSerializer(serializers.Serializer): name = serializers.CharField(max_length=25) - next = serializers.RecursiveField(required=False, allow_null=True) + next = serializers.RecursiveField(allow_null=True) class NodeSerializer(serializers.Serializer): @@ -104,6 +104,7 @@ class TestRecursiveField: 'name': 'something', 'next': { 'name': 'inner something', + 'next': None, } } } @@ -116,6 +117,7 @@ class TestRecursiveField: 'nullable': 'not null', 'links': { 'name': 'something', + 'next': None, } } serializer = SillySerializer(data=max_length) @@ -128,6 +130,7 @@ class TestRecursiveField: 'nullable': 'not null', 'links': { 'name': 'something', + 'next': None, } } serializer = SillySerializer(data=nulled_out) @@ -142,6 +145,7 @@ class TestRecursiveField: 'name': 'something', 'next': { 'name': 'inner something that is much too long', + 'next': None, } } } From 3c7bd7f7f5e52391be0bdcb9d6e5938aaac812df Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Thu, 29 Jan 2015 14:31:00 -0500 Subject: [PATCH 20/22] comments --- rest_framework/fields.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c32f4a754..7223c46b4 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1296,7 +1296,7 @@ class RecursiveField(Field): # This list of attributes determined by the attributes that # `rest_framework.serializers` calls to on a field object PROXIED_ATTRS = ( - # bound fields + # methods 'get_value', 'get_initial', 'run_validation', @@ -1313,13 +1313,22 @@ class RecursiveField(Field): ) def __init__(self, to=None, **kwargs): + """ + arguments: + to - `None`, the name of another serializer defined in the same module + as this serializer, or the fully qualified import path to another + serializer. e.g. `ExampleSerializer` or + `path.to.module.ExampleSerializer` + """ self.to = to self.kwargs = kwargs def bind(self, field_name, parent): if hasattr(parent, 'child') and parent.child is self: + # RecursiveField nested inside of a ListField parent_class = parent.parent.__class__ else: + # RecursiveField directly inside a Serializer parent_class = parent.__class__ if self.to is None: @@ -1337,6 +1346,7 @@ class RecursiveField(Field): raise ImportError( 'could not locate serializer %s' % self.to, e) + # Create a new serializer instance and proxy it proxied = proxied_class(**self.kwargs) proxied.bind(field_name, parent) self.proxied = proxied From 85714ee3d9effafde00e1b80badd664a29c096bb Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Tue, 3 Feb 2015 08:56:20 -0500 Subject: [PATCH 21/22] modelserializer support --- rest_framework/fields.py | 9 +++++++++ tests/test_recursive.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7223c46b4..f15e0131d 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -1323,6 +1323,15 @@ class RecursiveField(Field): self.to = to self.kwargs = kwargs + # Need to properly initialize by calling super-constructor for + # ModelSerializers + super_kwargs = dict( + (key, kwargs[key]) + for key in kwargs + if key in inspect.getargspec(Field.__init__) + ) + super(RecursiveField, self).__init__(**super_kwargs) + def bind(self, field_name, parent): if hasattr(parent, 'child') and parent.child is self: # RecursiveField nested inside of a ListField diff --git a/tests/test_recursive.py b/tests/test_recursive.py index a4524bbb7..e360f0e83 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -1,3 +1,4 @@ +from django.db import models from rest_framework import serializers @@ -32,6 +33,19 @@ class SillySerializer(serializers.Serializer): self = serializers.RecursiveField(required=False) +class RecursiveModel(models.Model): + name = models.CharField(max_length=255) + parent = models.ForeignKey('self', null=True) + + +class RecursiveModelSerializer(serializers.ModelSerializer): + parent = serializers.RecursiveField(allow_null=True) + + class Meta: + model = RecursiveModel + fields = ('name', 'parent') + + class TestRecursiveField: @staticmethod def serialize(serializer_class, value): @@ -152,3 +166,23 @@ class TestRecursiveField: serializer = SillySerializer(data=way_too_long) assert not serializer.is_valid(), \ 'validation should fail on inner link validation' + + def test_model_serializer(self): + one = RecursiveModel(name='one') + two = RecursiveModel(name='two', parent=one) + + #serialization + representation = { + 'name': 'two', + 'parent': { + 'name': 'one', + 'parent': None, + } + } + + s = RecursiveModelSerializer(two) + assert s.data == representation + + #deserialization + self.deserialize(RecursiveModelSerializer, representation) + From 8fadf54ff989cbefdaf5e9f584260613ae208afd Mon Sep 17 00:00:00 2001 From: Warren Jin Date: Tue, 3 Feb 2015 08:58:25 -0500 Subject: [PATCH 22/22] flake --- tests/test_recursive.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index e360f0e83..fe8ec2f16 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -171,7 +171,7 @@ class TestRecursiveField: one = RecursiveModel(name='one') two = RecursiveModel(name='two', parent=one) - #serialization + # serialization representation = { 'name': 'two', 'parent': { @@ -183,6 +183,5 @@ class TestRecursiveField: s = RecursiveModelSerializer(two) assert s.data == representation - #deserialization + # deserialization self.deserialize(RecursiveModelSerializer, representation) -