From a9f5551b03913a37a54f13c957c8641011a78b05 Mon Sep 17 00:00:00 2001 From: Oskar Hahn Date: Tue, 7 Jul 2015 12:50:32 +0200 Subject: [PATCH] Made it easier to change the field name of ModelSerializer. Added a way to override the ManyRelatedField class to use by an RelatedField The ManyRelatedField class was not changed by this commit, it had to be moved before the RelatedField class. See #3121 --- rest_framework/relations.py | 120 ++++++++++++++++++------------------ 1 file changed, 61 insertions(+), 59 deletions(-) diff --git a/rest_framework/relations.py b/rest_framework/relations.py index c5cbfebcd..509e4ffc8 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -35,7 +35,67 @@ MANY_RELATION_KWARGS = ( ) +class ManyRelatedField(Field): + """ + Relationships with `many=True` transparently get coerced into instead being + a ManyRelatedField with a child relationship. + + The `ManyRelatedField` class is responsible for handling iterating through + the values and passing each one to the child relationship. + + This class is treated as private API. + You shouldn't generally need to be using this class directly yourself, + and should instead simply set 'many=True' on the relationship. + """ + initial = [] + default_empty_html = [] + + def __init__(self, child_relation=None, *args, **kwargs): + self.child_relation = child_relation + assert child_relation is not None, '`child_relation` is a required argument.' + super(ManyRelatedField, self).__init__(*args, **kwargs) + self.child_relation.bind(field_name='', parent=self) + + def get_value(self, dictionary): + # We override the default field access in order to support + # lists in HTML forms. + if html.is_html_input(dictionary): + # Don't return [] if the update is partial + if self.field_name not in dictionary: + if getattr(self.root, 'partial', False): + return empty + return dictionary.getlist(self.field_name) + + return dictionary.get(self.field_name, empty) + + def to_internal_value(self, data): + return [ + self.child_relation.to_internal_value(item) + for item in data + ] + + def get_attribute(self, instance): + # Can't have any relationships if not created + if hasattr(instance, 'pk') and instance.pk is None: + return [] + + relationship = get_attribute(instance, self.source_attrs) + return relationship.all() if (hasattr(relationship, 'all')) else relationship + + def to_representation(self, iterable): + return [ + self.child_relation.to_representation(value) + for value in iterable + ] + + @property + def choices(self): + return self.child_relation.choices + + class RelatedField(Field): + many_related_field_class = ManyRelatedField + def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', None) assert self.queryset is not None or kwargs.get('read_only', None), ( @@ -77,7 +137,7 @@ class RelatedField(Field): for key in kwargs.keys(): if key in MANY_RELATION_KWARGS: list_kwargs[key] = kwargs[key] - return ManyRelatedField(**list_kwargs) + return cls.many_related_field_class(**list_kwargs) def run_validation(self, data=empty): # We force empty strings to None values for relational fields. @@ -338,61 +398,3 @@ class SlugRelatedField(RelatedField): def to_representation(self, obj): return getattr(obj, self.slug_field) - - -class ManyRelatedField(Field): - """ - Relationships with `many=True` transparently get coerced into instead being - a ManyRelatedField with a child relationship. - - The `ManyRelatedField` class is responsible for handling iterating through - the values and passing each one to the child relationship. - - This class is treated as private API. - You shouldn't generally need to be using this class directly yourself, - and should instead simply set 'many=True' on the relationship. - """ - initial = [] - default_empty_html = [] - - def __init__(self, child_relation=None, *args, **kwargs): - self.child_relation = child_relation - assert child_relation is not None, '`child_relation` is a required argument.' - super(ManyRelatedField, self).__init__(*args, **kwargs) - self.child_relation.bind(field_name='', parent=self) - - def get_value(self, dictionary): - # We override the default field access in order to support - # lists in HTML forms. - if html.is_html_input(dictionary): - # Don't return [] if the update is partial - if self.field_name not in dictionary: - if getattr(self.root, 'partial', False): - return empty - return dictionary.getlist(self.field_name) - - return dictionary.get(self.field_name, empty) - - def to_internal_value(self, data): - return [ - self.child_relation.to_internal_value(item) - for item in data - ] - - def get_attribute(self, instance): - # Can't have any relationships if not created - if hasattr(instance, 'pk') and instance.pk is None: - return [] - - relationship = get_attribute(instance, self.source_attrs) - return relationship.all() if (hasattr(relationship, 'all')) else relationship - - def to_representation(self, iterable): - return [ - self.child_relation.to_representation(value) - for value in iterable - ] - - @property - def choices(self): - return self.child_relation.choices