Refactor many

This commit is contained in:
Tom Christie 2014-11-13 21:11:13 +00:00
parent 78a741be27
commit 992330055e
2 changed files with 37 additions and 10 deletions

View File

@ -10,9 +10,17 @@ from django.utils.translation import ugettext_lazy as _
class PKOnlyObject(object): class PKOnlyObject(object):
"""
This is a mock object, used for when we only need the pk of the object
instance, but still want to return an object with a .pk attribute,
in order to keep the same interface as a regular model instance.
"""
def __init__(self, pk): def __init__(self, pk):
self.pk = pk self.pk = pk
# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
MANY_RELATION_KWARGS = ( MANY_RELATION_KWARGS = (
'read_only', 'write_only', 'required', 'default', 'initial', 'source', 'read_only', 'write_only', 'required', 'default', 'initial', 'source',
'label', 'help_text', 'style', 'error_messages' 'label', 'help_text', 'style', 'error_messages'
@ -36,13 +44,17 @@ class RelatedField(Field):
# We override this method in order to automagically create # We override this method in order to automagically create
# `ManyRelatedField` classes instead when `many=True` is set. # `ManyRelatedField` classes instead when `many=True` is set.
if kwargs.pop('many', False): if kwargs.pop('many', False):
list_kwargs = {'child_relation': cls(*args, **kwargs)} return cls.many_init(*args, **kwargs)
for key in kwargs.keys():
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key]
return ManyRelatedField(**list_kwargs)
return super(RelatedField, cls).__new__(cls, *args, **kwargs) return super(RelatedField, cls).__new__(cls, *args, **kwargs)
@classmethod
def many_init(cls, *args, **kwargs):
list_kwargs = {'child_relation': cls(*args, **kwargs)}
for key in kwargs.keys():
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key]
return ManyRelatedField(**list_kwargs)
def run_validation(self, data=empty): def run_validation(self, data=empty):
# We force empty strings to None values for relational fields. # We force empty strings to None values for relational fields.
if data == '': if data == '':

View File

@ -46,6 +46,9 @@ import warnings
from rest_framework.relations import * # NOQA from rest_framework.relations import * # NOQA
from rest_framework.fields import * # NOQA from rest_framework.fields import * # NOQA
# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
LIST_SERIALIZER_KWARGS = ( LIST_SERIALIZER_KWARGS = (
'read_only', 'write_only', 'required', 'default', 'initial', 'source', 'read_only', 'write_only', 'required', 'default', 'initial', 'source',
'label', 'help_text', 'style', 'error_messages', 'label', 'help_text', 'style', 'error_messages',
@ -73,13 +76,25 @@ class BaseSerializer(Field):
# We override this method in order to automagically create # We override this method in order to automagically create
# `ListSerializer` classes instead when `many=True` is set. # `ListSerializer` classes instead when `many=True` is set.
if kwargs.pop('many', False): if kwargs.pop('many', False):
list_kwargs = {'child': cls(*args, **kwargs)} return cls.many_init(*args, **kwargs)
for key in kwargs.keys():
if key in LIST_SERIALIZER_KWARGS:
list_kwargs[key] = kwargs[key]
return ListSerializer(*args, **list_kwargs)
return super(BaseSerializer, cls).__new__(cls, *args, **kwargs) return super(BaseSerializer, cls).__new__(cls, *args, **kwargs)
@classmethod
def many_init(cls, *args, **kwargs):
"""
This method implements the creation of a `ListSerializer` parent
class when `many=True` is used. You can customize it if you need to
control which keyword arguments are passed to the parent, and
which are passed to the child.
"""
child_serializer = cls(*args, **kwargs)
list_kwargs = {'child': child_serializer}
list_kwargs.update(dict([
(key, value) for key, value in kwargs.items()
if key in LIST_SERIALIZER_KWARGS
]))
return ListSerializer(*args, **list_kwargs)
def to_internal_value(self, data): def to_internal_value(self, data):
raise NotImplementedError('`to_internal_value()` must be implemented.') raise NotImplementedError('`to_internal_value()` must be implemented.')