diff --git a/rest_framework/fields.py b/rest_framework/fields.py index aad24f4ed..06a1ea6ff 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -21,6 +21,7 @@ import collections import copy import datetime import decimal +import importlib import inspect import re import uuid @@ -1311,7 +1312,9 @@ class RecursiveField(Field): # __init__ on both the RecursiveField and the proxied field using the exact # same arguments. - def __init__(self, **kwargs): + def __init__(self, to='self', to_module=None, **kwargs): + self.to = to + self.to_module = to_module field_kwargs = dict( (key, kwargs[key]) for key in kwargs @@ -1323,9 +1326,17 @@ class RecursiveField(Field): super(RecursiveField, self).bind(field_name, parent) if hasattr(parent, 'child') and parent.child is self: - proxy_class = parent.parent.__class__ + parent_class = parent.parent.__class__ else: - proxy_class = parent.__class__ + parent_class = parent.__class__ + + if self.to == 'self': + proxy_class = parent_class + else: + ref = importlib.import_module(self.to_module or parent_class.__module__) + for part in self.to.split('.'): + ref = getattr(ref, part) + proxy_class = ref proxy = proxy_class(**self._kwargs) proxy.bind(field_name, parent) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index 04a38c1c7..c675b3979 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -1,8 +1,7 @@ from rest_framework import serializers - class LinkSerializer(serializers.Serializer): - name = serializers.CharField() + name = serializers.CharField(max_length=25) next = serializers.RecursiveField(required=False, allow_null=True) @@ -11,6 +10,26 @@ class NodeSerializer(serializers.Serializer): children = serializers.ListField(child=serializers.RecursiveField()) +class PingSerializer(serializers.Serializer): + ping_id = serializers.IntegerField() + pong = serializers.RecursiveField('PongSerializer', required=False) + + +class PongSerializer(serializers.Serializer): + pong_id = serializers.IntegerField() + ping = PingSerializer() + + +class SillySerializer(serializers.Serializer): + name = serializers.RecursiveField( + 'CharField', 'rest_framework.fields', max_length=5) + blankable = serializers.RecursiveField( + 'CharField', 'rest_framework.fields', allow_blank=True) + nullable = serializers.RecursiveField( + 'CharField', 'rest_framework.fields', allow_null=True) + links = serializers.RecursiveField('LinkSerializer') + self = serializers.RecursiveField(required=False) + class TestRecursiveField: @staticmethod def serialize(serializer_class, value): @@ -57,3 +76,73 @@ class TestRecursiveField: self.serialize(NodeSerializer, value) self.deserialize(NodeSerializer, value) + + def test_ping_pong(self): + pong = { + 'pong_id': 4, + 'ping': { + 'ping_id': 3, + 'pong': { + 'pong_id': 2, + 'ping': { + 'ping_id': 1, + }, + }, + }, + } + self.serialize(PongSerializer, pong) + self.deserialize(PongSerializer, pong) + + def test_validation(self): + value = { + 'name': 'good', + 'blankable': '', + 'nullable': None, + 'links': { + 'name': 'something', + 'next': { + 'name': 'inner something', + } + } + } + self.serialize(SillySerializer, value) + self.deserialize(SillySerializer, value) + + max_length = { + 'name': 'too long', + 'blankable': 'not blank', + 'nullable': 'not null', + 'links': { + 'name': 'something', + } + } + serializer = SillySerializer(data=max_length) + assert not serializer.is_valid(), \ + 'validation should fail due to name too long' + + nulled_out = { + 'name': 'good', + 'blankable': None, + 'nullable': 'not null', + 'links': { + 'name': 'something', + } + } + serializer = SillySerializer(data=nulled_out) + assert not serializer.is_valid(), \ + 'validation should fail due to null field' + + way_too_long = { + 'name': 'good', + 'blankable': '', + 'nullable': None, + 'links': { + 'name': 'something', + 'next': { + 'name': 'inner something that is much too long', + } + } + } + serializer = SillySerializer(data=way_too_long) + assert not serializer.is_valid(), \ + 'validation should fail on inner link validation'