From 89da9d7e16bcb87d4df7486b8f468c59c0ee8742 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Fri, 23 Feb 2018 02:23:42 -0500 Subject: [PATCH] Add ListSerializer.get_child hook --- rest_framework/serializers.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e35e04440..efd346521 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -572,6 +572,13 @@ class ListSerializer(BaseSerializer): super(ListSerializer, self).bind(field_name, parent) self.partial = self.parent.partial + def get_child(self, instance=None, data=empty): + """ + Hook to override retrieval of the child. By default, this returns the + child instance provided during initialization. + """ + return self.child + def get_initial(self): if hasattr(self, 'initial_data'): return self.to_representation(self.initial_data) @@ -636,7 +643,7 @@ class ListSerializer(BaseSerializer): for item in data: try: - validated = self.child.run_validation(item) + validated = self.get_child(data=item).run_validation(item) except ValidationError as exc: errors.append(exc.detail) else: @@ -657,7 +664,7 @@ class ListSerializer(BaseSerializer): iterable = data.all() if isinstance(data, models.Manager) else data return [ - self.child.to_representation(item) for item in iterable + self.get_child(instance=item).to_representation(item) for item in iterable ] def validate(self, attrs): @@ -674,7 +681,7 @@ class ListSerializer(BaseSerializer): def create(self, validated_data): return [ - self.child.create(attrs) for attrs in validated_data + self.get_child(data=attrs).create(attrs) for attrs in validated_data ] def save(self, **kwargs):