mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-22 09:36:49 +03:00
Fix ModelSerializer unique_together handling for field sources (#7143)
* Fix ModelSerializer unique_together field sources Updates ModelSerializer to check for serializer fields that map to the model field sources in the unique_together lists. * Ensure field name ordering consistency
This commit is contained in:
parent
00e6079e94
commit
089162e6e3
|
@ -13,7 +13,7 @@ response content is handled by parsers and renderers.
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import traceback
|
import traceback
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict, defaultdict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
|
||||||
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
|
||||||
|
@ -1508,28 +1508,55 @@ class ModelSerializer(Serializer):
|
||||||
# which may map onto a model field. Any dotted field name lookups
|
# which may map onto a model field. Any dotted field name lookups
|
||||||
# cannot map to a field, and must be a traversal, so we're not
|
# cannot map to a field, and must be a traversal, so we're not
|
||||||
# including those.
|
# including those.
|
||||||
field_names = {
|
field_sources = OrderedDict(
|
||||||
field.source for field in self._writable_fields
|
(field.field_name, field.source) for field in self._writable_fields
|
||||||
if (field.source != '*') and ('.' not in field.source)
|
if (field.source != '*') and ('.' not in field.source)
|
||||||
}
|
)
|
||||||
|
|
||||||
# Special Case: Add read_only fields with defaults.
|
# Special Case: Add read_only fields with defaults.
|
||||||
field_names |= {
|
field_sources.update(OrderedDict(
|
||||||
field.source for field in self.fields.values()
|
(field.field_name, field.source) for field in self.fields.values()
|
||||||
if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
|
if (field.read_only) and (field.default != empty) and (field.source != '*') and ('.' not in field.source)
|
||||||
}
|
))
|
||||||
|
|
||||||
|
# Invert so we can find the serializer field names that correspond to
|
||||||
|
# the model field names in the unique_together sets. This also allows
|
||||||
|
# us to check that multiple fields don't map to the same source.
|
||||||
|
source_map = defaultdict(list)
|
||||||
|
for name, source in field_sources.items():
|
||||||
|
source_map[source].append(name)
|
||||||
|
|
||||||
# Note that we make sure to check `unique_together` both on the
|
# Note that we make sure to check `unique_together` both on the
|
||||||
# base model class, but also on any parent classes.
|
# base model class, but also on any parent classes.
|
||||||
validators = []
|
validators = []
|
||||||
for parent_class in model_class_inheritance_tree:
|
for parent_class in model_class_inheritance_tree:
|
||||||
for unique_together in parent_class._meta.unique_together:
|
for unique_together in parent_class._meta.unique_together:
|
||||||
if field_names.issuperset(set(unique_together)):
|
# Skip if serializer does not map to all unique together sources
|
||||||
validator = UniqueTogetherValidator(
|
if not set(source_map).issuperset(set(unique_together)):
|
||||||
queryset=parent_class._default_manager,
|
continue
|
||||||
fields=unique_together
|
|
||||||
|
for source in unique_together:
|
||||||
|
assert len(source_map[source]) == 1, (
|
||||||
|
"Unable to create `UniqueTogetherValidator` for "
|
||||||
|
"`{model}.{field}` as `{serializer}` has multiple "
|
||||||
|
"fields ({fields}) that map to this model field. "
|
||||||
|
"Either remove the extra fields, or override "
|
||||||
|
"`Meta.validators` with a `UniqueTogetherValidator` "
|
||||||
|
"using the desired field names."
|
||||||
|
.format(
|
||||||
|
model=self.Meta.model.__name__,
|
||||||
|
serializer=self.__class__.__name__,
|
||||||
|
field=source,
|
||||||
|
fields=', '.join(source_map[source]),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
validators.append(validator)
|
|
||||||
|
field_names = tuple(source_map[f][0] for f in unique_together)
|
||||||
|
validator = UniqueTogetherValidator(
|
||||||
|
queryset=parent_class._default_manager,
|
||||||
|
fields=field_names
|
||||||
|
)
|
||||||
|
validators.append(validator)
|
||||||
return validators
|
return validators
|
||||||
|
|
||||||
def get_unique_for_date_validators(self):
|
def get_unique_for_date_validators(self):
|
||||||
|
|
|
@ -344,6 +344,49 @@ class TestUniquenessTogetherValidation(TestCase):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def test_default_validator_with_fields_with_source(self):
|
||||||
|
class TestSerializer(serializers.ModelSerializer):
|
||||||
|
name = serializers.CharField(source='race_name')
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = UniquenessTogetherModel
|
||||||
|
fields = ['name', 'position']
|
||||||
|
|
||||||
|
serializer = TestSerializer()
|
||||||
|
expected = dedent("""
|
||||||
|
TestSerializer():
|
||||||
|
name = CharField(source='race_name')
|
||||||
|
position = IntegerField()
|
||||||
|
class Meta:
|
||||||
|
validators = [<UniqueTogetherValidator(queryset=UniquenessTogetherModel.objects.all(), fields=('name', 'position'))>]
|
||||||
|
""")
|
||||||
|
assert repr(serializer) == expected
|
||||||
|
|
||||||
|
def test_default_validator_with_multiple_fields_with_same_source(self):
|
||||||
|
class TestSerializer(serializers.ModelSerializer):
|
||||||
|
name = serializers.CharField(source='race_name')
|
||||||
|
other_name = serializers.CharField(source='race_name')
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
model = UniquenessTogetherModel
|
||||||
|
fields = ['name', 'other_name', 'position']
|
||||||
|
|
||||||
|
serializer = TestSerializer(data={
|
||||||
|
'name': 'foo',
|
||||||
|
'other_name': 'foo',
|
||||||
|
'position': 1,
|
||||||
|
})
|
||||||
|
with pytest.raises(AssertionError) as excinfo:
|
||||||
|
serializer.is_valid()
|
||||||
|
|
||||||
|
expected = (
|
||||||
|
"Unable to create `UniqueTogetherValidator` for "
|
||||||
|
"`UniquenessTogetherModel.race_name` as `TestSerializer` has "
|
||||||
|
"multiple fields (name, other_name) that map to this model field. "
|
||||||
|
"Either remove the extra fields, or override `Meta.validators` "
|
||||||
|
"with a `UniqueTogetherValidator` using the desired field names.")
|
||||||
|
assert str(excinfo.value) == expected
|
||||||
|
|
||||||
def test_allow_explict_override(self):
|
def test_allow_explict_override(self):
|
||||||
"""
|
"""
|
||||||
Ensure validators can be explicitly removed..
|
Ensure validators can be explicitly removed..
|
||||||
|
|
Loading…
Reference in New Issue
Block a user