diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 71a9f1938..f15e0131d 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -4,7 +4,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ValidationError as DjangoValidationError from django.core.validators import RegexValidator from django.forms import ImageField as DjangoImageField -from django.utils import six, timezone +from django.utils import six, timezone, importlib from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type, smart_text from django.utils.translation import ugettext_lazy as _ @@ -1156,6 +1156,9 @@ class ListField(Field): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert not inspect.isclass(self.child), '`child` has not been instantiated.' super(ListField, self).__init__(*args, **kwargs) + + def bind(self, field_name, parent): + super(ListField, self).bind(field_name, parent) self.child.bind(field_name='', parent=self) def get_value(self, dictionary): @@ -1270,6 +1273,104 @@ class HiddenField(Field): return data +class RecursiveField(Field): + """ + A field that gets its representation from its parent. + + This method could be used to serialize a tree structure, a linked list, or + even a directed acyclic graph. As with all recursive things, it is + important to keep the base case in mind. In the case of the tree serializer + example below, the base case is a node with an empty list of children. In + the case of the list serializer below, the base case is when `next==None`. + Above all, beware of cyclical references. + + Examples: + + class TreeSerializer(self): + children = ListField(child=RecursiveField()) + + class ListSerializer(self): + next = RecursiveField(allow_null=True) + """ + + # This list of attributes determined by the attributes that + # `rest_framework.serializers` calls to on a field object + PROXIED_ATTRS = ( + # methods + 'get_value', + 'get_initial', + 'run_validation', + 'get_attribute', + 'to_representation', + + # attributes + 'field_name', + 'source', + 'read_only', + 'default', + 'source_attrs', + 'write_only', + ) + + def __init__(self, to=None, **kwargs): + """ + arguments: + to - `None`, the name of another serializer defined in the same module + as this serializer, or the fully qualified import path to another + serializer. e.g. `ExampleSerializer` or + `path.to.module.ExampleSerializer` + """ + self.to = to + self.kwargs = kwargs + + # Need to properly initialize by calling super-constructor for + # ModelSerializers + super_kwargs = dict( + (key, kwargs[key]) + for key in kwargs + if key in inspect.getargspec(Field.__init__) + ) + super(RecursiveField, self).__init__(**super_kwargs) + + def bind(self, field_name, parent): + if hasattr(parent, 'child') and parent.child is self: + # RecursiveField nested inside of a ListField + parent_class = parent.parent.__class__ + else: + # RecursiveField directly inside a Serializer + parent_class = parent.__class__ + + if self.to is None: + proxied_class = parent_class + else: + try: + module_name, class_name = self.to.rsplit('.', 1) + except ValueError: + module_name, class_name = parent_class.__module__, self.to + + try: + proxied_class = getattr( + importlib.import_module(module_name), class_name) + except Exception as e: + raise ImportError( + 'could not locate serializer %s' % self.to, e) + + # Create a new serializer instance and proxy it + proxied = proxied_class(**self.kwargs) + proxied.bind(field_name, parent) + self.proxied = proxied + + def __getattribute__(self, name): + if name in RecursiveField.PROXIED_ATTRS: + try: + proxied = object.__getattribute__(self, 'proxied') + return getattr(proxied, name) + except AttributeError: + pass + + return object.__getattribute__(self, name) + + class SerializerMethodField(Field): """ A read-only field that get its representation from calling a method on the diff --git a/tests/test_recursive.py b/tests/test_recursive.py new file mode 100644 index 000000000..fe8ec2f16 --- /dev/null +++ b/tests/test_recursive.py @@ -0,0 +1,187 @@ +from django.db import models +from rest_framework import serializers + + +class LinkSerializer(serializers.Serializer): + name = serializers.CharField(max_length=25) + next = serializers.RecursiveField(allow_null=True) + + +class NodeSerializer(serializers.Serializer): + name = serializers.CharField() + children = serializers.ListField(child=serializers.RecursiveField()) + + +class PingSerializer(serializers.Serializer): + ping_id = serializers.IntegerField() + pong = serializers.RecursiveField('PongSerializer', required=False) + + +class PongSerializer(serializers.Serializer): + pong_id = serializers.IntegerField() + ping = PingSerializer() + + +class SillySerializer(serializers.Serializer): + name = serializers.RecursiveField( + 'rest_framework.fields.CharField', max_length=5) + blankable = serializers.RecursiveField( + 'rest_framework.fields.CharField', allow_blank=True) + nullable = serializers.RecursiveField( + 'rest_framework.fields.CharField', allow_null=True) + links = serializers.RecursiveField('LinkSerializer') + self = serializers.RecursiveField(required=False) + + +class RecursiveModel(models.Model): + name = models.CharField(max_length=255) + parent = models.ForeignKey('self', null=True) + + +class RecursiveModelSerializer(serializers.ModelSerializer): + parent = serializers.RecursiveField(allow_null=True) + + class Meta: + model = RecursiveModel + fields = ('name', 'parent') + + +class TestRecursiveField: + @staticmethod + def serialize(serializer_class, value): + serializer = serializer_class(value) + + assert serializer.data == value, \ + 'serialized data does not match input' + + @staticmethod + def deserialize(serializer_class, data): + serializer = serializer_class(data=data) + + assert serializer.is_valid(), \ + 'cannot validate on deserialization: %s' % dict(serializer.errors) + assert serializer.validated_data == data, \ + 'deserialized data does not match input' + + def test_link_serializer(self): + value = { + 'name': 'first', + 'next': { + 'name': 'second', + 'next': { + 'name': 'third', + 'next': None, + } + } + } + + self.serialize(LinkSerializer, value) + self.deserialize(LinkSerializer, value) + + def test_node_serializer(self): + value = { + 'name': 'root', + 'children': [{ + 'name': 'first child', + 'children': [], + }, { + 'name': 'second child', + 'children': [], + }] + } + + self.serialize(NodeSerializer, value) + self.deserialize(NodeSerializer, value) + + def test_ping_pong(self): + pong = { + 'pong_id': 4, + 'ping': { + 'ping_id': 3, + 'pong': { + 'pong_id': 2, + 'ping': { + 'ping_id': 1, + }, + }, + }, + } + self.serialize(PongSerializer, pong) + self.deserialize(PongSerializer, pong) + + def test_validation(self): + value = { + 'name': 'good', + 'blankable': '', + 'nullable': None, + 'links': { + 'name': 'something', + 'next': { + 'name': 'inner something', + 'next': None, + } + } + } + self.serialize(SillySerializer, value) + self.deserialize(SillySerializer, value) + + max_length = { + 'name': 'too long', + 'blankable': 'not blank', + 'nullable': 'not null', + 'links': { + 'name': 'something', + 'next': None, + } + } + serializer = SillySerializer(data=max_length) + assert not serializer.is_valid(), \ + 'validation should fail due to name too long' + + nulled_out = { + 'name': 'good', + 'blankable': None, + 'nullable': 'not null', + 'links': { + 'name': 'something', + 'next': None, + } + } + serializer = SillySerializer(data=nulled_out) + assert not serializer.is_valid(), \ + 'validation should fail due to null field' + + way_too_long = { + 'name': 'good', + 'blankable': '', + 'nullable': None, + 'links': { + 'name': 'something', + 'next': { + 'name': 'inner something that is much too long', + 'next': None, + } + } + } + serializer = SillySerializer(data=way_too_long) + assert not serializer.is_valid(), \ + 'validation should fail on inner link validation' + + def test_model_serializer(self): + one = RecursiveModel(name='one') + two = RecursiveModel(name='two', parent=one) + + # serialization + representation = { + 'name': 'two', + 'parent': { + 'name': 'one', + 'parent': None, + } + } + + s = RecursiveModelSerializer(two) + assert s.data == representation + + # deserialization + self.deserialize(RecursiveModelSerializer, representation)