This commit is contained in:
heywbj 2015-02-03 13:58:29 +00:00
commit 77caa782fd
2 changed files with 289 additions and 1 deletions

View File

@ -4,7 +4,7 @@ from django.core.exceptions import ObjectDoesNotExist
from django.core.exceptions import ValidationError as DjangoValidationError from django.core.exceptions import ValidationError as DjangoValidationError
from django.core.validators import RegexValidator from django.core.validators import RegexValidator
from django.forms import ImageField as DjangoImageField from django.forms import ImageField as DjangoImageField
from django.utils import six, timezone from django.utils import six, timezone, importlib
from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.encoding import is_protected_type, smart_text from django.utils.encoding import is_protected_type, smart_text
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
@ -1156,6 +1156,9 @@ class ListField(Field):
self.child = kwargs.pop('child', copy.deepcopy(self.child)) self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert not inspect.isclass(self.child), '`child` has not been instantiated.' assert not inspect.isclass(self.child), '`child` has not been instantiated.'
super(ListField, self).__init__(*args, **kwargs) super(ListField, self).__init__(*args, **kwargs)
def bind(self, field_name, parent):
super(ListField, self).bind(field_name, parent)
self.child.bind(field_name='', parent=self) self.child.bind(field_name='', parent=self)
def get_value(self, dictionary): def get_value(self, dictionary):
@ -1270,6 +1273,104 @@ class HiddenField(Field):
return data return data
class RecursiveField(Field):
"""
A field that gets its representation from its parent.
This method could be used to serialize a tree structure, a linked list, or
even a directed acyclic graph. As with all recursive things, it is
important to keep the base case in mind. In the case of the tree serializer
example below, the base case is a node with an empty list of children. In
the case of the list serializer below, the base case is when `next==None`.
Above all, beware of cyclical references.
Examples:
class TreeSerializer(self):
children = ListField(child=RecursiveField())
class ListSerializer(self):
next = RecursiveField(allow_null=True)
"""
# This list of attributes determined by the attributes that
# `rest_framework.serializers` calls to on a field object
PROXIED_ATTRS = (
# methods
'get_value',
'get_initial',
'run_validation',
'get_attribute',
'to_representation',
# attributes
'field_name',
'source',
'read_only',
'default',
'source_attrs',
'write_only',
)
def __init__(self, to=None, **kwargs):
"""
arguments:
to - `None`, the name of another serializer defined in the same module
as this serializer, or the fully qualified import path to another
serializer. e.g. `ExampleSerializer` or
`path.to.module.ExampleSerializer`
"""
self.to = to
self.kwargs = kwargs
# Need to properly initialize by calling super-constructor for
# ModelSerializers
super_kwargs = dict(
(key, kwargs[key])
for key in kwargs
if key in inspect.getargspec(Field.__init__)
)
super(RecursiveField, self).__init__(**super_kwargs)
def bind(self, field_name, parent):
if hasattr(parent, 'child') and parent.child is self:
# RecursiveField nested inside of a ListField
parent_class = parent.parent.__class__
else:
# RecursiveField directly inside a Serializer
parent_class = parent.__class__
if self.to is None:
proxied_class = parent_class
else:
try:
module_name, class_name = self.to.rsplit('.', 1)
except ValueError:
module_name, class_name = parent_class.__module__, self.to
try:
proxied_class = getattr(
importlib.import_module(module_name), class_name)
except Exception as e:
raise ImportError(
'could not locate serializer %s' % self.to, e)
# Create a new serializer instance and proxy it
proxied = proxied_class(**self.kwargs)
proxied.bind(field_name, parent)
self.proxied = proxied
def __getattribute__(self, name):
if name in RecursiveField.PROXIED_ATTRS:
try:
proxied = object.__getattribute__(self, 'proxied')
return getattr(proxied, name)
except AttributeError:
pass
return object.__getattribute__(self, name)
class SerializerMethodField(Field): class SerializerMethodField(Field):
""" """
A read-only field that get its representation from calling a method on the A read-only field that get its representation from calling a method on the

187
tests/test_recursive.py Normal file
View File

@ -0,0 +1,187 @@
from django.db import models
from rest_framework import serializers
class LinkSerializer(serializers.Serializer):
name = serializers.CharField(max_length=25)
next = serializers.RecursiveField(allow_null=True)
class NodeSerializer(serializers.Serializer):
name = serializers.CharField()
children = serializers.ListField(child=serializers.RecursiveField())
class PingSerializer(serializers.Serializer):
ping_id = serializers.IntegerField()
pong = serializers.RecursiveField('PongSerializer', required=False)
class PongSerializer(serializers.Serializer):
pong_id = serializers.IntegerField()
ping = PingSerializer()
class SillySerializer(serializers.Serializer):
name = serializers.RecursiveField(
'rest_framework.fields.CharField', max_length=5)
blankable = serializers.RecursiveField(
'rest_framework.fields.CharField', allow_blank=True)
nullable = serializers.RecursiveField(
'rest_framework.fields.CharField', allow_null=True)
links = serializers.RecursiveField('LinkSerializer')
self = serializers.RecursiveField(required=False)
class RecursiveModel(models.Model):
name = models.CharField(max_length=255)
parent = models.ForeignKey('self', null=True)
class RecursiveModelSerializer(serializers.ModelSerializer):
parent = serializers.RecursiveField(allow_null=True)
class Meta:
model = RecursiveModel
fields = ('name', 'parent')
class TestRecursiveField:
@staticmethod
def serialize(serializer_class, value):
serializer = serializer_class(value)
assert serializer.data == value, \
'serialized data does not match input'
@staticmethod
def deserialize(serializer_class, data):
serializer = serializer_class(data=data)
assert serializer.is_valid(), \
'cannot validate on deserialization: %s' % dict(serializer.errors)
assert serializer.validated_data == data, \
'deserialized data does not match input'
def test_link_serializer(self):
value = {
'name': 'first',
'next': {
'name': 'second',
'next': {
'name': 'third',
'next': None,
}
}
}
self.serialize(LinkSerializer, value)
self.deserialize(LinkSerializer, value)
def test_node_serializer(self):
value = {
'name': 'root',
'children': [{
'name': 'first child',
'children': [],
}, {
'name': 'second child',
'children': [],
}]
}
self.serialize(NodeSerializer, value)
self.deserialize(NodeSerializer, value)
def test_ping_pong(self):
pong = {
'pong_id': 4,
'ping': {
'ping_id': 3,
'pong': {
'pong_id': 2,
'ping': {
'ping_id': 1,
},
},
},
}
self.serialize(PongSerializer, pong)
self.deserialize(PongSerializer, pong)
def test_validation(self):
value = {
'name': 'good',
'blankable': '',
'nullable': None,
'links': {
'name': 'something',
'next': {
'name': 'inner something',
'next': None,
}
}
}
self.serialize(SillySerializer, value)
self.deserialize(SillySerializer, value)
max_length = {
'name': 'too long',
'blankable': 'not blank',
'nullable': 'not null',
'links': {
'name': 'something',
'next': None,
}
}
serializer = SillySerializer(data=max_length)
assert not serializer.is_valid(), \
'validation should fail due to name too long'
nulled_out = {
'name': 'good',
'blankable': None,
'nullable': 'not null',
'links': {
'name': 'something',
'next': None,
}
}
serializer = SillySerializer(data=nulled_out)
assert not serializer.is_valid(), \
'validation should fail due to null field'
way_too_long = {
'name': 'good',
'blankable': '',
'nullable': None,
'links': {
'name': 'something',
'next': {
'name': 'inner something that is much too long',
'next': None,
}
}
}
serializer = SillySerializer(data=way_too_long)
assert not serializer.is_valid(), \
'validation should fail on inner link validation'
def test_model_serializer(self):
one = RecursiveModel(name='one')
two = RecursiveModel(name='two', parent=one)
# serialization
representation = {
'name': 'two',
'parent': {
'name': 'one',
'parent': None,
}
}
s = RecursiveModelSerializer(two)
assert s.data == representation
# deserialization
self.deserialize(RecursiveModelSerializer, representation)