Fix compat issues

This commit is contained in:
Tom Christie 2014-09-22 14:54:33 +01:00
parent c54f394904
commit 249253a144
2 changed files with 30 additions and 14 deletions

View File

@ -209,8 +209,10 @@ class Field(object):
""" """
Validate a simple representation and return the internal value. Validate a simple representation and return the internal value.
The provided data may be `empty` if no representation was included. The provided data may be `empty` if no representation was included
May return `empty` if the field should not be included in the in the input.
May raise `SkipField` if the field should not be included in the
validated data. validated data.
""" """
if data is empty: if data is empty:
@ -223,6 +225,10 @@ class Field(object):
return value return value
def run_validators(self, value): def run_validators(self, value):
"""
Test the given value against all the validators on the field,
and either raise a `ValidationError` or simply return.
"""
if value in (None, '', [], (), {}): if value in (None, '', [], (), {}):
return return
@ -753,8 +759,9 @@ class MultipleChoiceField(ChoiceField):
} }
def to_internal_value(self, data): def to_internal_value(self, data):
if not hasattr(data, '__iter__'): if isinstance(data, type('')) or not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__) self.fail('not_a_list', input_type=type(data).__name__)
return set([ return set([
super(MultipleChoiceField, self).to_internal_value(item) super(MultipleChoiceField, self).to_internal_value(item)
for item in data for item in data

View File

@ -5,22 +5,31 @@ import datetime
import pytest import pytest
def get_items(mapping_or_list_of_two_tuples):
# Tests accept either lists of two tuples, or dictionaries.
if isinstance(mapping_or_list_of_two_tuples, dict):
# {value: expected}
return mapping_or_list_of_two_tuples.items()
# [(value, expected), ...]
return mapping_or_list_of_two_tuples
class ValidAndInvalidValues: class ValidAndInvalidValues:
""" """
Base class for testing valid and invalid field values. Base class for testing valid and invalid input values.
""" """
def test_valid_values(self): def test_valid_values(self):
""" """
Ensure that valid values return the expected validated data. Ensure that valid values return the expected validated data.
""" """
for input_value, expected_output in self.valid_mappings.items(): for input_value, expected_output in get_items(self.valid_mappings):
assert self.field.run_validation(input_value) == expected_output assert self.field.run_validation(input_value) == expected_output
def test_invalid_values(self): def test_invalid_values(self):
""" """
Ensure that invalid values raise the expected validation error. Ensure that invalid values raise the expected validation error.
""" """
for input_value, expected_failure in self.invalid_mappings.items(): for input_value, expected_failure in get_items(self.invalid_mappings):
with pytest.raises(fields.ValidationError) as exc_info: with pytest.raises(fields.ValidationError) as exc_info:
self.field.run_validation(input_value) self.field.run_validation(input_value)
assert exc_info.value.messages == expected_failure assert exc_info.value.messages == expected_failure
@ -189,14 +198,14 @@ class TestDecimalField(ValidAndInvalidValues):
12.3: Decimal('12.3'), 12.3: Decimal('12.3'),
0.1: Decimal('0.1'), 0.1: Decimal('0.1'),
} }
invalid_mappings = { invalid_mappings = (
'abc': ["A valid number is required."], ('abc', ["A valid number is required."]),
Decimal('Nan'): ["A valid number is required."], (Decimal('Nan'), ["A valid number is required."]),
Decimal('Inf'): ["A valid number is required."], (Decimal('Inf'), ["A valid number is required."]),
'12.345': ["Ensure that there are no more than 3 digits in total."], ('12.345', ["Ensure that there are no more than 3 digits in total."]),
'0.01': ["Ensure that there are no more than 1 decimal places."], ('0.01', ["Ensure that there are no more than 1 decimal places."]),
123: ["Ensure that there are no more than 2 digits before the decimal point."] (123, ["Ensure that there are no more than 2 digits before the decimal point."])
} )
field = fields.DecimalField(max_digits=3, decimal_places=1) field = fields.DecimalField(max_digits=3, decimal_places=1)