Improve ListSerializer

This commit is contained in:
Felipe Martins Diel 2024-01-27 03:34:08 -03:00
parent 41edb3b9dd
commit e20af60767

View File

@ -28,7 +28,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework.compat import postgres_fields from rest_framework.compat import postgres_fields
from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.fields import get_error_detail from rest_framework.fields import get_attribute, get_error_detail
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html, model_meta, representation from rest_framework.utils import html, model_meta, representation
from rest_framework.utils.field_mapping import ( from rest_framework.utils.field_mapping import (
@ -610,11 +610,6 @@ class ListSerializer(BaseSerializer):
assert self.child is not None, '`child` is a required argument.' assert self.child is not None, '`child` is a required argument.'
assert not inspect.isclass(self.child), '`child` has not been instantiated.' assert not inspect.isclass(self.child), '`child` has not been instantiated.'
instance = kwargs.get('instance', [])
data = kwargs.get('data', [])
if instance and data:
assert len(data) == len(instance), 'Data and instance should have same length'
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.child.bind(field_name='', parent=self) self.child.bind(field_name='', parent=self)
@ -700,13 +695,37 @@ class ListSerializer(BaseSerializer):
ret = [] ret = []
errors = [] errors = []
for idx, item in enumerate(data): if not self.instance and self.parent and self.parent.instance:
if ( rel_mgr = get_attribute(self.parent.instance, self.source_attrs)
hasattr(self, 'instance') self.instance = rel_mgr.all() if rel_mgr else None
and self.instance
and len(self.instance) > idx pk_field_name = next(
): field_name
self.child.instance = self.instance[idx] for field_name, field in self.child.fields.items()
if field.source == self.child.Meta.model._meta.pk.attname
)
item_instance_dict = {}
if self.instance:
item_pks = []
for item in data:
item_pk = item.get(pk_field_name)
if item_pk:
item_pks.append(item_pk)
for item_instance in self.instance.filter(pk__in=item_pks):
item_instance_dict[item_instance.pk] = item_instance
for item in data:
self.child.initial_data = item
if self.instance:
item_pk = item.get(pk_field_name)
self.child.instance = (
item_instance_dict.get(item_pk) if item_pk else None
)
try: try:
validated = self.run_child_validation(item) validated = self.run_child_validation(item)
except ValidationError as exc: except ValidationError as exc: