diff --git a/rest_framework/fields.py b/rest_framework/fields.py index cce597771..c9110aa37 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -4,7 +4,7 @@ from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ValidationError as DjangoValidationError from django.core.validators import RegexValidator from django.forms import ImageField as DjangoImageField -from django.utils import six, timezone +from django.utils import six, timezone, importlib from django.utils.dateparse import parse_date, parse_datetime, parse_time from django.utils.encoding import is_protected_type, smart_text from django.utils.translation import ugettext_lazy as _ @@ -21,7 +21,6 @@ import collections import copy import datetime import decimal -import importlib import inspect import re import uuid @@ -1294,47 +1293,62 @@ class RecursiveField(Field): next = RecursiveField(allow_null=True) """ - def __init__(self, to='self', to_module=None, **kwargs): + # This list of attributes determined by the attributes that + # `rest_framework.serializers` calls to on a field object + PROXIED_ATTRS = ( + # bound fields + 'get_value', + 'get_initial', + 'run_validation', + 'get_attribute', + 'to_representation', + + # attributes + 'field_name', + 'source', + 'read_only', + 'default', + 'source_attrs', + 'write_only', + ) + + def __init__(self, to=None, **kwargs): self.to = to - self.to_module = to_module - field_kwargs = dict( - (key, kwargs[key]) - for key in kwargs - if key in inspect.getargspec(Field.__init__) - ) - super(RecursiveField, self).__init__(**field_kwargs) + self.kwargs = kwargs def bind(self, field_name, parent): - super(RecursiveField, self).bind(field_name, parent) - if hasattr(parent, 'child') and parent.child is self: parent_class = parent.parent.__class__ else: parent_class = parent.__class__ - if self.to == 'self': - proxy_class = parent_class + if self.to is None: + proxied_class = parent_class else: - ref = importlib.import_module(self.to_module or parent_class.__module__) - for part in self.to.split('.'): - ref = getattr(ref, part) - proxy_class = ref + try: + module_name, class_name = self.to.rsplit('.', 1) + except ValueError: + module_name, class_name = parent_class.__module__, self.to - proxy = proxy_class(**self._kwargs) - proxy.bind(field_name, parent) - self.proxy = proxy + try: + proxied_class = getattr( + importlib.import_module(module_name), class_name) + except Exception as e: + raise ImportError( + 'could not locate serializer %s' % self.to, e) + + proxied = proxied_class(**self.kwargs) + proxied.bind(field_name, parent) + self.proxied = proxied def __getattribute__(self, name): - d = object.__getattribute__(self, '__dict__') - - if 'proxy' in d: + if name in RecursiveField.PROXIED_ATTRS: try: - attr = getattr(d['proxy'], name) - - if hasattr(attr, '__self__'): - return attr + proxied = object.__getattribute__(self, 'proxied') + return getattr(proxied, name) except AttributeError: pass + return object.__getattribute__(self, name) diff --git a/tests/test_recursive.py b/tests/test_recursive.py index b343bd727..f3e60d2b2 100644 --- a/tests/test_recursive.py +++ b/tests/test_recursive.py @@ -23,11 +23,11 @@ class PongSerializer(serializers.Serializer): class SillySerializer(serializers.Serializer): name = serializers.RecursiveField( - 'CharField', 'rest_framework.fields', max_length=5) + 'rest_framework.fields.CharField', max_length=5) blankable = serializers.RecursiveField( - 'CharField', 'rest_framework.fields', allow_blank=True) + 'rest_framework.fields.CharField', allow_blank=True) nullable = serializers.RecursiveField( - 'CharField', 'rest_framework.fields', allow_null=True) + 'rest_framework.fields.CharField', allow_null=True) links = serializers.RecursiveField('LinkSerializer') self = serializers.RecursiveField(required=False)