diff --git a/rest_framework/fields.py b/rest_framework/fields.py index d1aebbafc..446732c3a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -109,7 +109,8 @@ class Field(object): def __init__(self, read_only=False, write_only=False, required=None, default=empty, initial=None, source=None, label=None, help_text=None, style=None, - error_messages=None, validators=[], allow_null=False): + error_messages=None, validators=[], allow_null=False, + context=None): self._creation_counter = Field._creation_counter Field._creation_counter += 1 @@ -135,6 +136,11 @@ class Field(object): self.validators = validators or self.default_validators[:] self.allow_null = allow_null + # These are set up by `.bind()` when the field is added to a serializer. + self.field_name = None + self.parent = None + self._context = {} if (context is None) else context + # Collect default error message from self and parent classes messages = {} for cls in reversed(self.__class__.__mro__): @@ -157,7 +163,14 @@ class Field(object): kwargs = copy.deepcopy(self._kwargs) return self.__class__(*args, **kwargs) - def bind(self, field_name, parent, root): + @property + def context(self): + root = self + while root.parent is not None: + root = root.parent + return root._context + + def bind(self, field_name, parent): """ Setup the context for the field instance. """ @@ -174,10 +187,8 @@ class Field(object): self.field_name = field_name self.parent = parent - self.root = root - self.context = parent.context - # `self.label` should deafult to being based on the field name. + # `self.label` should default to being based on the field name. if self.label is None: self.label = field_name.replace('_', ' ').capitalize() diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 5aa1f8bd8..b37a6fedd 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -243,11 +243,6 @@ class ManyRelation(Field): assert child_relation is not None, '`child_relation` is a required argument.' super(ManyRelation, self).__init__(*args, **kwargs) - def bind(self, field_name, parent, root): - # ManyRelation needs to provide the current context to the child relation. - super(ManyRelation, self).bind(field_name, parent, root) - self.child_relation.bind(field_name, parent, root) - def to_internal_value(self, data): return [ self.child_relation.to_internal_value(item) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 12e380900..04721c7a3 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -150,13 +150,20 @@ class SerializerMetaclass(type): class BindingDict(object): + """ + This dict-like object is used to store fields on a serializer. + + This ensures that whenever fields are added to the serializer we call + `field.bind()` so that the `field_name` and `parent` attributes + can be set correctly. + """ def __init__(self, serializer): self.serializer = serializer self.fields = SortedDict() def __setitem__(self, key, field): self.fields[key] = field - field.bind(field_name=key, parent=self.serializer, root=self.serializer) + field.bind(field_name=key, parent=self.serializer) def __getitem__(self, key): return self.fields[key] @@ -174,7 +181,6 @@ class BindingDict(object): @six.add_metaclass(SerializerMetaclass) class Serializer(BaseSerializer): def __init__(self, *args, **kwargs): - self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) kwargs.pop('many', None) @@ -198,13 +204,6 @@ class Serializer(BaseSerializer): def _get_base_fields(self): return copy.deepcopy(self._declared_fields) - def bind(self, field_name, parent, root): - # If the serializer is used as a field then when it becomes bound - # it also needs to bind all its child fields. - super(Serializer, self).bind(field_name, parent, root) - for field_name, field in self.fields.items(): - field.bind(field_name, self, root) - def get_initial(self): return dict([ (field.field_name, field.get_initial()) @@ -290,17 +289,10 @@ class ListSerializer(BaseSerializer): self.child = kwargs.pop('child', copy.deepcopy(self.child)) assert self.child is not None, '`child` is a required argument.' assert not inspect.isclass(self.child), '`child` has not been instantiated.' - self.context = kwargs.pop('context', {}) kwargs.pop('partial', None) super(ListSerializer, self).__init__(*args, **kwargs) - self.child.bind('', self, self) - - def bind(self, field_name, parent, root): - # If the list is used as a field then it needs to provide - # the current context to the child serializer. - super(ListSerializer, self).bind(field_name, parent, root) - self.child.bind(field_name, self, root) + self.child.bind(field_name='', parent=self) def get_value(self, dictionary): # We override the default field access in order to support diff --git a/tests/test_relations.py b/tests/test_relations.py index c29618ce2..2d11672b8 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -51,7 +51,7 @@ class TestHyperlinkedIdentityField(APISimpleTestCase): self.instance = MockObject(pk=1, name='foo') self.field = serializers.HyperlinkedIdentityField(view_name='example') self.field.reverse = mock_reverse - self.field.context = {'request': True} + self.field._context = {'request': True} def test_representation(self): representation = self.field.to_representation(self.instance) @@ -62,7 +62,7 @@ class TestHyperlinkedIdentityField(APISimpleTestCase): assert representation is None def test_representation_with_format(self): - self.field.context['format'] = 'xml' + self.field._context['format'] = 'xml' representation = self.field.to_representation(self.instance) assert representation == 'http://example.org/example/1.xml/' @@ -91,14 +91,14 @@ class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase): self.instance = MockObject(pk=1, name='foo') self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json') self.field.reverse = mock_reverse - self.field.context = {'request': True} + self.field._context = {'request': True} def test_representation(self): representation = self.field.to_representation(self.instance) assert representation == 'http://example.org/example/1/' def test_representation_with_format(self): - self.field.context['format'] = 'xml' + self.field._context['format'] = 'xml' representation = self.field.to_representation(self.instance) assert representation == 'http://example.org/example/1.json/'