cleaner use of getattribute

This commit is contained in:
Warren Jin 2015-01-28 02:29:13 -05:00
parent d4c1922389
commit e04324daab
3 changed files with 88 additions and 79 deletions

View File

@ -1293,6 +1293,24 @@ class RecursiveField(Field):
next = RecursiveField(allow_null=True)
"""
# Implementation notes
#
# Only use __getattribute__ to forward the bound methods. Stop short of
# forwarding all attributes for the following reasons:
# - if you forward the __class__ attribute then deepcopy will give you back
# the wrong class
# - if you forward the fields attribute then __repr__ will enter into an
# infinite recursion
# - who knows what other infinite recursions are possible
#
# We only forward bound methods, but there are some attributes that must be
# accessible on both the RecursiveField and the proxied serializer, namely:
# field_name, read_only, default, source_attrs, write_attrs, source. As far
# as I can tell, the cleanest way to make these fields availabe without
# piecemeal forwarding them through __getattribute__ is to call bind and
# __init__ on both the RecursiveField and the proxied field using the exact
# same arguments.
def __init__(self, **kwargs):
field_kwargs = dict(
(key, kwargs[key])
@ -1302,6 +1320,8 @@ class RecursiveField(Field):
super(RecursiveField, self).__init__(**field_kwargs)
def bind(self, field_name, parent):
super(RecursiveField, self).bind(field_name, parent)
if hasattr(parent, 'child') and parent.child is self:
proxy_class = parent.parent.__class__
else:
@ -1314,19 +1334,15 @@ class RecursiveField(Field):
def __getattribute__(self, name):
d = object.__getattribute__(self, '__dict__')
# do not alias the fields parameter to prevent __repr__ from
# infinite recursion
if 'proxy' in d and name != 'fields' and name != 'proxy' and \
not (name.startswith('__') and name.endswith('__')):
return getattr(d['proxy'], name)
else:
return object.__getattribute__(self, name)
if 'proxy' in d:
try:
attr = getattr(d['proxy'], name)
def __setattr__(self, name, value):
if 'proxy' in self.__dict__ and name is not 'proxy':
setattr(self.__dict__['proxy'], name, value)
else:
self.__dict__[name] = value
if hasattr(attr, '__self__'):
return attr
except AttributeError:
pass
return object.__getattribute__(self, name)
class SerializerMethodField(Field):

View File

@ -293,72 +293,6 @@ class TestCreateOnlyDefault:
}
# Tests for RecursiveField.
# -------------------------
class TestRecursiveField:
def setup(self):
class LinkSerializer(serializers.Serializer):
name = serializers.CharField()
next = serializers.RecursiveField(required=False, allow_null=True)
self.link_serializer = LinkSerializer
class NodeSerializer(serializers.Serializer):
name = serializers.CharField()
children = serializers.ListField(child=serializers.RecursiveField())
self.node_serializer = NodeSerializer
def test_link_serializer(self):
value = {
'name': 'first',
'next': {
'name': 'second',
'next': {
'name': 'third',
'next': None,
}
}
}
# test serialization
serializer = self.link_serializer(value)
assert serializer.data == value, \
'serialized data does not match input'
# test deserialization
serializer = self.link_serializer(data=value)
assert serializer.is_valid(), \
'cannot validate on deserialization: %s' % dict(serializer.errors)
assert serializer.validated_data == value, \
'deserialized data does not match input'
def test_node_serializer(self):
value = {
'name': 'root',
'children': [{
'name': 'first child',
'children': [],
}, {
'name': 'second child',
'children': [],
}]
}
# serialization
serializer = self.node_serializer(value)
assert serializer.data == value, \
'serialized data does not match input'
# deserialization
serializer = self.node_serializer(data=value)
assert serializer.is_valid(), \
'cannot validate on deserialization: %s' % dict(serializer.errors)
assert serializer.validated_data == value, \
'deserialized data does not match input'
# Tests for field input and output values.
# ----------------------------------------

59
tests/test_recursive.py Normal file
View File

@ -0,0 +1,59 @@
from rest_framework import serializers
class LinkSerializer(serializers.Serializer):
name = serializers.CharField()
next = serializers.RecursiveField(required=False, allow_null=True)
class NodeSerializer(serializers.Serializer):
name = serializers.CharField()
children = serializers.ListField(child=serializers.RecursiveField())
class TestRecursiveField:
@staticmethod
def serialize(serializer_class, value):
serializer = serializer_class(value)
assert serializer.data == value, \
'serialized data does not match input'
@staticmethod
def deserialize(serializer_class, data):
serializer = serializer_class(data=data)
assert serializer.is_valid(), \
'cannot validate on deserialization: %s' % dict(serializer.errors)
assert serializer.validated_data == data, \
'deserialized data does not match input'
def test_link_serializer(self):
value = {
'name': 'first',
'next': {
'name': 'second',
'next': {
'name': 'third',
'next': None,
}
}
}
self.serialize(LinkSerializer, value)
self.deserialize(LinkSerializer, value)
def test_node_serializer(self):
value = {
'name': 'root',
'children': [{
'name': 'first child',
'children': [],
}, {
'name': 'second child',
'children': [],
}]
}
self.serialize(NodeSerializer, value)
self.deserialize(NodeSerializer, value)