Clean up bind - no longer needs to be called multiple times in nested fields

This commit is contained in:
Tom Christie 2014-09-25 11:40:32 +01:00
parent b22c9602fa
commit 64632da371
4 changed files with 29 additions and 31 deletions

View File

@ -109,7 +109,8 @@ class Field(object):
def __init__(self, read_only=False, write_only=False, def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None, required=None, default=empty, initial=None, source=None,
label=None, help_text=None, style=None, label=None, help_text=None, style=None,
error_messages=None, validators=[], allow_null=False): error_messages=None, validators=[], allow_null=False,
context=None):
self._creation_counter = Field._creation_counter self._creation_counter = Field._creation_counter
Field._creation_counter += 1 Field._creation_counter += 1
@ -135,6 +136,11 @@ class Field(object):
self.validators = validators or self.default_validators[:] self.validators = validators or self.default_validators[:]
self.allow_null = allow_null self.allow_null = allow_null
# These are set up by `.bind()` when the field is added to a serializer.
self.field_name = None
self.parent = None
self._context = {} if (context is None) else context
# Collect default error message from self and parent classes # Collect default error message from self and parent classes
messages = {} messages = {}
for cls in reversed(self.__class__.__mro__): for cls in reversed(self.__class__.__mro__):
@ -157,7 +163,14 @@ class Field(object):
kwargs = copy.deepcopy(self._kwargs) kwargs = copy.deepcopy(self._kwargs)
return self.__class__(*args, **kwargs) return self.__class__(*args, **kwargs)
def bind(self, field_name, parent, root): @property
def context(self):
root = self
while root.parent is not None:
root = root.parent
return root._context
def bind(self, field_name, parent):
""" """
Setup the context for the field instance. Setup the context for the field instance.
""" """
@ -174,10 +187,8 @@ class Field(object):
self.field_name = field_name self.field_name = field_name
self.parent = parent self.parent = parent
self.root = root
self.context = parent.context
# `self.label` should deafult to being based on the field name. # `self.label` should default to being based on the field name.
if self.label is None: if self.label is None:
self.label = field_name.replace('_', ' ').capitalize() self.label = field_name.replace('_', ' ').capitalize()

View File

@ -243,11 +243,6 @@ class ManyRelation(Field):
assert child_relation is not None, '`child_relation` is a required argument.' assert child_relation is not None, '`child_relation` is a required argument.'
super(ManyRelation, self).__init__(*args, **kwargs) super(ManyRelation, self).__init__(*args, **kwargs)
def bind(self, field_name, parent, root):
# ManyRelation needs to provide the current context to the child relation.
super(ManyRelation, self).bind(field_name, parent, root)
self.child_relation.bind(field_name, parent, root)
def to_internal_value(self, data): def to_internal_value(self, data):
return [ return [
self.child_relation.to_internal_value(item) self.child_relation.to_internal_value(item)

View File

@ -150,13 +150,20 @@ class SerializerMetaclass(type):
class BindingDict(object): class BindingDict(object):
"""
This dict-like object is used to store fields on a serializer.
This ensures that whenever fields are added to the serializer we call
`field.bind()` so that the `field_name` and `parent` attributes
can be set correctly.
"""
def __init__(self, serializer): def __init__(self, serializer):
self.serializer = serializer self.serializer = serializer
self.fields = SortedDict() self.fields = SortedDict()
def __setitem__(self, key, field): def __setitem__(self, key, field):
self.fields[key] = field self.fields[key] = field
field.bind(field_name=key, parent=self.serializer, root=self.serializer) field.bind(field_name=key, parent=self.serializer)
def __getitem__(self, key): def __getitem__(self, key):
return self.fields[key] return self.fields[key]
@ -174,7 +181,6 @@ class BindingDict(object):
@six.add_metaclass(SerializerMetaclass) @six.add_metaclass(SerializerMetaclass)
class Serializer(BaseSerializer): class Serializer(BaseSerializer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.context = kwargs.pop('context', {})
kwargs.pop('partial', None) kwargs.pop('partial', None)
kwargs.pop('many', None) kwargs.pop('many', None)
@ -198,13 +204,6 @@ class Serializer(BaseSerializer):
def _get_base_fields(self): def _get_base_fields(self):
return copy.deepcopy(self._declared_fields) return copy.deepcopy(self._declared_fields)
def bind(self, field_name, parent, root):
# If the serializer is used as a field then when it becomes bound
# it also needs to bind all its child fields.
super(Serializer, self).bind(field_name, parent, root)
for field_name, field in self.fields.items():
field.bind(field_name, self, root)
def get_initial(self): def get_initial(self):
return dict([ return dict([
(field.field_name, field.get_initial()) (field.field_name, field.get_initial())
@ -290,17 +289,10 @@ class ListSerializer(BaseSerializer):
self.child = kwargs.pop('child', copy.deepcopy(self.child)) self.child = kwargs.pop('child', copy.deepcopy(self.child))
assert self.child is not None, '`child` is a required argument.' assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.' assert not inspect.isclass(self.child), '`child` has not been instantiated.'
self.context = kwargs.pop('context', {})
kwargs.pop('partial', None) kwargs.pop('partial', None)
super(ListSerializer, self).__init__(*args, **kwargs) super(ListSerializer, self).__init__(*args, **kwargs)
self.child.bind('', self, self) self.child.bind(field_name='', parent=self)
def bind(self, field_name, parent, root):
# If the list is used as a field then it needs to provide
# the current context to the child serializer.
super(ListSerializer, self).bind(field_name, parent, root)
self.child.bind(field_name, self, root)
def get_value(self, dictionary): def get_value(self, dictionary):
# We override the default field access in order to support # We override the default field access in order to support

View File

@ -51,7 +51,7 @@ class TestHyperlinkedIdentityField(APISimpleTestCase):
self.instance = MockObject(pk=1, name='foo') self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example') self.field = serializers.HyperlinkedIdentityField(view_name='example')
self.field.reverse = mock_reverse self.field.reverse = mock_reverse
self.field.context = {'request': True} self.field._context = {'request': True}
def test_representation(self): def test_representation(self):
representation = self.field.to_representation(self.instance) representation = self.field.to_representation(self.instance)
@ -62,7 +62,7 @@ class TestHyperlinkedIdentityField(APISimpleTestCase):
assert representation is None assert representation is None
def test_representation_with_format(self): def test_representation_with_format(self):
self.field.context['format'] = 'xml' self.field._context['format'] = 'xml'
representation = self.field.to_representation(self.instance) representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1.xml/' assert representation == 'http://example.org/example/1.xml/'
@ -91,14 +91,14 @@ class TestHyperlinkedIdentityFieldWithFormat(APISimpleTestCase):
self.instance = MockObject(pk=1, name='foo') self.instance = MockObject(pk=1, name='foo')
self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json') self.field = serializers.HyperlinkedIdentityField(view_name='example', format='json')
self.field.reverse = mock_reverse self.field.reverse = mock_reverse
self.field.context = {'request': True} self.field._context = {'request': True}
def test_representation(self): def test_representation(self):
representation = self.field.to_representation(self.instance) representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1/' assert representation == 'http://example.org/example/1/'
def test_representation_with_format(self): def test_representation_with_format(self):
self.field.context['format'] = 'xml' self.field._context['format'] = 'xml'
representation = self.field.to_representation(self.instance) representation = self.field.to_representation(self.instance)
assert representation == 'http://example.org/example/1.json/' assert representation == 'http://example.org/example/1.json/'