diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index d773902e8..4f853ed58 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -571,6 +571,13 @@ class ListSerializer(BaseSerializer): super(ListSerializer, self).bind(field_name, parent) self.partial = self.parent.partial + def get_child(self, instance=None, data=empty): + """ + Hook to override retrieval of the child. By default, this returns the + child instance provided during initialization. + """ + return self.child + def get_initial(self): if hasattr(self, 'initial_data'): return self.to_representation(self.initial_data) @@ -635,7 +642,7 @@ class ListSerializer(BaseSerializer): for item in data: try: - validated = self.child.run_validation(item) + validated = self.get_child(data=item).run_validation(item) except ValidationError as exc: errors.append(exc.detail) else: @@ -656,7 +663,7 @@ class ListSerializer(BaseSerializer): iterable = data.all() if isinstance(data, models.Manager) else data return [ - self.child.to_representation(item) for item in iterable + self.get_child(instance=item).to_representation(item) for item in iterable ] def validate(self, attrs): @@ -673,7 +680,7 @@ class ListSerializer(BaseSerializer): def create(self, validated_data): return [ - self.child.create(attrs) for attrs in validated_data + self.get_child(data=attrs).create(attrs) for attrs in validated_data ] def save(self, **kwargs):