Add ListSerializer.get_child hook

This commit is contained in:
Ryan P Kilby 2018-02-23 02:23:42 -05:00
parent d2994e0596
commit 89da9d7e16

View File

@ -572,6 +572,13 @@ class ListSerializer(BaseSerializer):
super(ListSerializer, self).bind(field_name, parent) super(ListSerializer, self).bind(field_name, parent)
self.partial = self.parent.partial self.partial = self.parent.partial
def get_child(self, instance=None, data=empty):
"""
Hook to override retrieval of the child. By default, this returns the
child instance provided during initialization.
"""
return self.child
def get_initial(self): def get_initial(self):
if hasattr(self, 'initial_data'): if hasattr(self, 'initial_data'):
return self.to_representation(self.initial_data) return self.to_representation(self.initial_data)
@ -636,7 +643,7 @@ class ListSerializer(BaseSerializer):
for item in data: for item in data:
try: try:
validated = self.child.run_validation(item) validated = self.get_child(data=item).run_validation(item)
except ValidationError as exc: except ValidationError as exc:
errors.append(exc.detail) errors.append(exc.detail)
else: else:
@ -657,7 +664,7 @@ class ListSerializer(BaseSerializer):
iterable = data.all() if isinstance(data, models.Manager) else data iterable = data.all() if isinstance(data, models.Manager) else data
return [ return [
self.child.to_representation(item) for item in iterable self.get_child(instance=item).to_representation(item) for item in iterable
] ]
def validate(self, attrs): def validate(self, attrs):
@ -674,7 +681,7 @@ class ListSerializer(BaseSerializer):
def create(self, validated_data): def create(self, validated_data):
return [ return [
self.child.create(attrs) for attrs in validated_data self.get_child(data=attrs).create(attrs) for attrs in validated_data
] ]
def save(self, **kwargs): def save(self, **kwargs):