diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 5105dfcb4..5fb99a42f 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -209,8 +209,10 @@ class Field(object): """ Validate a simple representation and return the internal value. - The provided data may be `empty` if no representation was included. - May return `empty` if the field should not be included in the + The provided data may be `empty` if no representation was included + in the input. + + May raise `SkipField` if the field should not be included in the validated data. """ if data is empty: @@ -223,6 +225,10 @@ class Field(object): return value def run_validators(self, value): + """ + Test the given value against all the validators on the field, + and either raise a `ValidationError` or simply return. + """ if value in (None, '', [], (), {}): return @@ -753,8 +759,9 @@ class MultipleChoiceField(ChoiceField): } def to_internal_value(self, data): - if not hasattr(data, '__iter__'): + if isinstance(data, type('')) or not hasattr(data, '__iter__'): self.fail('not_a_list', input_type=type(data).__name__) + return set([ super(MultipleChoiceField, self).to_internal_value(item) for item in data diff --git a/tests/test_fields.py b/tests/test_fields.py index ae7f19193..e03ece544 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -5,22 +5,31 @@ import datetime import pytest +def get_items(mapping_or_list_of_two_tuples): + # Tests accept either lists of two tuples, or dictionaries. + if isinstance(mapping_or_list_of_two_tuples, dict): + # {value: expected} + return mapping_or_list_of_two_tuples.items() + # [(value, expected), ...] + return mapping_or_list_of_two_tuples + + class ValidAndInvalidValues: """ - Base class for testing valid and invalid field values. + Base class for testing valid and invalid input values. """ def test_valid_values(self): """ Ensure that valid values return the expected validated data. """ - for input_value, expected_output in self.valid_mappings.items(): + for input_value, expected_output in get_items(self.valid_mappings): assert self.field.run_validation(input_value) == expected_output def test_invalid_values(self): """ Ensure that invalid values raise the expected validation error. """ - for input_value, expected_failure in self.invalid_mappings.items(): + for input_value, expected_failure in get_items(self.invalid_mappings): with pytest.raises(fields.ValidationError) as exc_info: self.field.run_validation(input_value) assert exc_info.value.messages == expected_failure @@ -189,14 +198,14 @@ class TestDecimalField(ValidAndInvalidValues): 12.3: Decimal('12.3'), 0.1: Decimal('0.1'), } - invalid_mappings = { - 'abc': ["A valid number is required."], - Decimal('Nan'): ["A valid number is required."], - Decimal('Inf'): ["A valid number is required."], - '12.345': ["Ensure that there are no more than 3 digits in total."], - '0.01': ["Ensure that there are no more than 1 decimal places."], - 123: ["Ensure that there are no more than 2 digits before the decimal point."] - } + invalid_mappings = ( + ('abc', ["A valid number is required."]), + (Decimal('Nan'), ["A valid number is required."]), + (Decimal('Inf'), ["A valid number is required."]), + ('12.345', ["Ensure that there are no more than 3 digits in total."]), + ('0.01', ["Ensure that there are no more than 1 decimal places."]), + (123, ["Ensure that there are no more than 2 digits before the decimal point."]) + ) field = fields.DecimalField(max_digits=3, decimal_places=1)