mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-05 21:10:13 +03:00
cleaner use of getattribute
This commit is contained in:
parent
d4c1922389
commit
e04324daab
|
@ -1293,6 +1293,24 @@ class RecursiveField(Field):
|
|||
next = RecursiveField(allow_null=True)
|
||||
"""
|
||||
|
||||
# Implementation notes
|
||||
#
|
||||
# Only use __getattribute__ to forward the bound methods. Stop short of
|
||||
# forwarding all attributes for the following reasons:
|
||||
# - if you forward the __class__ attribute then deepcopy will give you back
|
||||
# the wrong class
|
||||
# - if you forward the fields attribute then __repr__ will enter into an
|
||||
# infinite recursion
|
||||
# - who knows what other infinite recursions are possible
|
||||
#
|
||||
# We only forward bound methods, but there are some attributes that must be
|
||||
# accessible on both the RecursiveField and the proxied serializer, namely:
|
||||
# field_name, read_only, default, source_attrs, write_attrs, source. As far
|
||||
# as I can tell, the cleanest way to make these fields availabe without
|
||||
# piecemeal forwarding them through __getattribute__ is to call bind and
|
||||
# __init__ on both the RecursiveField and the proxied field using the exact
|
||||
# same arguments.
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
field_kwargs = dict(
|
||||
(key, kwargs[key])
|
||||
|
@ -1302,6 +1320,8 @@ class RecursiveField(Field):
|
|||
super(RecursiveField, self).__init__(**field_kwargs)
|
||||
|
||||
def bind(self, field_name, parent):
|
||||
super(RecursiveField, self).bind(field_name, parent)
|
||||
|
||||
if hasattr(parent, 'child') and parent.child is self:
|
||||
proxy_class = parent.parent.__class__
|
||||
else:
|
||||
|
@ -1314,19 +1334,15 @@ class RecursiveField(Field):
|
|||
def __getattribute__(self, name):
|
||||
d = object.__getattribute__(self, '__dict__')
|
||||
|
||||
# do not alias the fields parameter to prevent __repr__ from
|
||||
# infinite recursion
|
||||
if 'proxy' in d and name != 'fields' and name != 'proxy' and \
|
||||
not (name.startswith('__') and name.endswith('__')):
|
||||
return getattr(d['proxy'], name)
|
||||
else:
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if 'proxy' in self.__dict__ and name is not 'proxy':
|
||||
setattr(self.__dict__['proxy'], name, value)
|
||||
else:
|
||||
self.__dict__[name] = value
|
||||
if 'proxy' in d:
|
||||
try:
|
||||
attr = getattr(d['proxy'], name)
|
||||
|
||||
if hasattr(attr, '__self__'):
|
||||
return attr
|
||||
except AttributeError:
|
||||
pass
|
||||
return object.__getattribute__(self, name)
|
||||
|
||||
|
||||
class SerializerMethodField(Field):
|
||||
|
|
|
@ -293,72 +293,6 @@ class TestCreateOnlyDefault:
|
|||
}
|
||||
|
||||
|
||||
# Tests for RecursiveField.
|
||||
# -------------------------
|
||||
|
||||
class TestRecursiveField:
|
||||
def setup(self):
|
||||
class LinkSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
next = serializers.RecursiveField(required=False, allow_null=True)
|
||||
self.link_serializer = LinkSerializer
|
||||
|
||||
class NodeSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
children = serializers.ListField(child=serializers.RecursiveField())
|
||||
self.node_serializer = NodeSerializer
|
||||
|
||||
def test_link_serializer(self):
|
||||
value = {
|
||||
'name': 'first',
|
||||
'next': {
|
||||
'name': 'second',
|
||||
'next': {
|
||||
'name': 'third',
|
||||
'next': None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# test serialization
|
||||
serializer = self.link_serializer(value)
|
||||
|
||||
assert serializer.data == value, \
|
||||
'serialized data does not match input'
|
||||
|
||||
# test deserialization
|
||||
serializer = self.link_serializer(data=value)
|
||||
|
||||
assert serializer.is_valid(), \
|
||||
'cannot validate on deserialization: %s' % dict(serializer.errors)
|
||||
assert serializer.validated_data == value, \
|
||||
'deserialized data does not match input'
|
||||
|
||||
def test_node_serializer(self):
|
||||
value = {
|
||||
'name': 'root',
|
||||
'children': [{
|
||||
'name': 'first child',
|
||||
'children': [],
|
||||
}, {
|
||||
'name': 'second child',
|
||||
'children': [],
|
||||
}]
|
||||
}
|
||||
|
||||
# serialization
|
||||
serializer = self.node_serializer(value)
|
||||
assert serializer.data == value, \
|
||||
'serialized data does not match input'
|
||||
|
||||
# deserialization
|
||||
serializer = self.node_serializer(data=value)
|
||||
assert serializer.is_valid(), \
|
||||
'cannot validate on deserialization: %s' % dict(serializer.errors)
|
||||
assert serializer.validated_data == value, \
|
||||
'deserialized data does not match input'
|
||||
|
||||
|
||||
# Tests for field input and output values.
|
||||
# ----------------------------------------
|
||||
|
||||
|
|
59
tests/test_recursive.py
Normal file
59
tests/test_recursive.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
from rest_framework import serializers
|
||||
|
||||
|
||||
class LinkSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
next = serializers.RecursiveField(required=False, allow_null=True)
|
||||
|
||||
|
||||
class NodeSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
children = serializers.ListField(child=serializers.RecursiveField())
|
||||
|
||||
|
||||
class TestRecursiveField:
|
||||
@staticmethod
|
||||
def serialize(serializer_class, value):
|
||||
serializer = serializer_class(value)
|
||||
|
||||
assert serializer.data == value, \
|
||||
'serialized data does not match input'
|
||||
|
||||
@staticmethod
|
||||
def deserialize(serializer_class, data):
|
||||
serializer = serializer_class(data=data)
|
||||
|
||||
assert serializer.is_valid(), \
|
||||
'cannot validate on deserialization: %s' % dict(serializer.errors)
|
||||
assert serializer.validated_data == data, \
|
||||
'deserialized data does not match input'
|
||||
|
||||
def test_link_serializer(self):
|
||||
value = {
|
||||
'name': 'first',
|
||||
'next': {
|
||||
'name': 'second',
|
||||
'next': {
|
||||
'name': 'third',
|
||||
'next': None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.serialize(LinkSerializer, value)
|
||||
self.deserialize(LinkSerializer, value)
|
||||
|
||||
def test_node_serializer(self):
|
||||
value = {
|
||||
'name': 'root',
|
||||
'children': [{
|
||||
'name': 'first child',
|
||||
'children': [],
|
||||
}, {
|
||||
'name': 'second child',
|
||||
'children': [],
|
||||
}]
|
||||
}
|
||||
|
||||
self.serialize(NodeSerializer, value)
|
||||
self.deserialize(NodeSerializer, value)
|
Loading…
Reference in New Issue
Block a user