From 1ece516d2d0d942d9de513f85d601afcccf67ebd Mon Sep 17 00:00:00 2001 From: Si Feng Date: Tue, 19 Feb 2019 07:38:20 -0800 Subject: [PATCH] Adjusted field `validators` to accept iterables. (#6282) Closes 6280. --- rest_framework/fields.py | 4 ++-- rest_framework/serializers.py | 4 ++-- tests/test_fields.py | 19 +++++++++++++++++++ tests/test_serializer.py | 34 +++++++++++++++++++++++++++++++++- 4 files changed, 56 insertions(+), 5 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 562e52b22..2cbfd22bb 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -350,7 +350,7 @@ class Field(object): self.default_empty_html = default if validators is not None: - self.validators = validators[:] + self.validators = list(validators) # These are set up by `.bind()` when the field is added to a serializer. self.field_name = None @@ -410,7 +410,7 @@ class Field(object): self._validators = validators def get_validators(self): - return self.default_validators[:] + return list(self.default_validators) def get_initial(self): """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index eae08a34c..9830edb3f 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -393,7 +393,7 @@ class Serializer(BaseSerializer): # Used by the lazily-evaluated `validators` property. meta = getattr(self, 'Meta', None) validators = getattr(meta, 'validators', None) - return validators[:] if validators else [] + return list(validators) if validators else [] def get_initial(self): if hasattr(self, 'initial_data'): @@ -1480,7 +1480,7 @@ class ModelSerializer(Serializer): # If the validators have been declared explicitly then use that. validators = getattr(getattr(self, 'Meta', None), 'validators', None) if validators is not None: - return validators[:] + return list(validators) # Otherwise use the default set of validators. return ( diff --git a/tests/test_fields.py b/tests/test_fields.py index 9a1d04979..12c936b22 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -740,6 +740,25 @@ class TestCharField(FieldValues): 'Null characters are not allowed.' ] + def test_iterable_validators(self): + """ + Ensure `validators` parameter is compatible with reasonable iterables. + """ + value = 'example' + + for validators in ([], (), set()): + field = serializers.CharField(validators=validators) + field.run_validation(value) + + def raise_exception(value): + raise exceptions.ValidationError('Raised error') + + for validators in ([raise_exception], (raise_exception,), set([raise_exception])): + field = serializers.CharField(validators=validators) + with pytest.raises(serializers.ValidationError) as exc_info: + field.run_validation(value) + assert exc_info.value.detail == ['Raised error'] + class TestEmailField(FieldValues): """ diff --git a/tests/test_serializer.py b/tests/test_serializer.py index efa1adf0e..6e4ff22b2 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -10,7 +10,7 @@ from collections import Mapping import pytest from django.db import models -from rest_framework import fields, relations, serializers +from rest_framework import exceptions, fields, relations, serializers from rest_framework.compat import unicode_repr from rest_framework.fields import Field @@ -183,6 +183,38 @@ class TestSerializer: assert serializer.validated_data.coords[1] == 50.941357 assert serializer.errors == {} + def test_iterable_validators(self): + """ + Ensure `validators` parameter is compatible with reasonable iterables. + """ + data = {'char': 'abc', 'integer': 123} + + for validators in ([], (), set()): + class ExampleSerializer(serializers.Serializer): + char = serializers.CharField(validators=validators) + integer = serializers.IntegerField() + + serializer = ExampleSerializer(data=data) + assert serializer.is_valid() + assert serializer.validated_data == data + assert serializer.errors == {} + + def raise_exception(value): + raise exceptions.ValidationError('Raised error') + + for validators in ([raise_exception], (raise_exception,), set([raise_exception])): + class ExampleSerializer(serializers.Serializer): + char = serializers.CharField(validators=validators) + integer = serializers.IntegerField() + + serializer = ExampleSerializer(data=data) + assert not serializer.is_valid() + assert serializer.data == data + assert serializer.validated_data == {} + assert serializer.errors == {'char': [ + exceptions.ErrorDetail(string='Raised error', code='invalid') + ]} + class TestValidateMethod: def test_non_field_error_validate_method(self):