More test sorting

This commit is contained in:
Tom Christie 2014-09-08 14:24:05 +01:00
parent d934824bff
commit 21980b800d
10 changed files with 190 additions and 115 deletions

View File

@ -29,11 +29,6 @@ class ParseError(APIException):
default_detail = 'Malformed request.' default_detail = 'Malformed request.'
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Invalid data in request.'
class AuthenticationFailed(APIException): class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Incorrect authentication credentials.' default_detail = 'Incorrect authentication credentials.'

View File

@ -1,4 +1,6 @@
from rest_framework.exceptions import ValidationError from django.core import validators
from django.core.exceptions import ValidationError
from django.utils.encoding import is_protected_type
from rest_framework.utils import html from rest_framework.utils import html
import inspect import inspect
@ -33,9 +35,14 @@ def get_attribute(instance, attrs):
""" """
Similar to Python's built in `getattr(instance, attr)`, Similar to Python's built in `getattr(instance, attr)`,
but takes a list of nested attributes, instead of a single attribute. but takes a list of nested attributes, instead of a single attribute.
Also accepts either attribute lookup on objects or dictionary lookups.
""" """
for attr in attrs: for attr in attrs:
instance = getattr(instance, attr) try:
instance = getattr(instance, attr)
except AttributeError:
return instance[attr]
return instance return instance
@ -80,9 +87,11 @@ class Field(object):
'not exist in the `MESSAGES` dictionary.' 'not exist in the `MESSAGES` dictionary.'
) )
default_validators = []
def __init__(self, read_only=False, write_only=False, def __init__(self, read_only=False, write_only=False,
required=None, default=empty, initial=None, source=None, required=None, default=empty, initial=None, source=None,
label=None, style=None, error_messages=None): label=None, style=None, error_messages=None, validators=[]):
self._creation_counter = Field._creation_counter self._creation_counter = Field._creation_counter
Field._creation_counter += 1 Field._creation_counter += 1
@ -104,6 +113,7 @@ class Field(object):
self.initial = initial self.initial = initial
self.label = label self.label = label
self.style = {} if style is None else style self.style = {} if style is None else style
self.validators = self.default_validators + validators
def bind(self, field_name, parent, root): def bind(self, field_name, parent, root):
""" """
@ -176,8 +186,21 @@ class Field(object):
self.fail('required') self.fail('required')
return self.get_default() return self.get_default()
self.run_validators(data)
return self.to_native(data) return self.to_native(data)
def run_validators(self, value):
if value in validators.EMPTY_VALUES:
return
errors = []
for validator in self.validators:
try:
validator(value)
except ValidationError as exc:
errors.extend(exc.messages)
if errors:
raise ValidationError(errors)
def to_native(self, data): def to_native(self, data):
""" """
Transform the *incoming* primative data into a native value. Transform the *incoming* primative data into a native value.
@ -322,9 +345,13 @@ class IntegerField(Field):
} }
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.max_value = kwargs.pop('max_value') max_value = kwargs.pop('max_value', None)
self.min_value = kwargs.pop('min_value') min_value = kwargs.pop('min_value', None)
super(CharField, self).__init__(**kwargs) super(IntegerField, self).__init__(**kwargs)
if max_value is not None:
self.validators.append(validators.MaxValueValidator(max_value))
if min_value is not None:
self.validators.append(validators.MinValueValidator(min_value))
def to_native(self, data): def to_native(self, data):
try: try:
@ -392,3 +419,49 @@ class MethodField(Field):
attr = 'get_{field_name}'.format(field_name=self.field_name) attr = 'get_{field_name}'.format(field_name=self.field_name)
method = getattr(self.parent, attr) method = getattr(self.parent, attr)
return method(value) return method(value)
class ModelField(Field):
"""
A generic field that can be used against an arbitrary model field.
"""
def __init__(self, *args, **kwargs):
try:
self.model_field = kwargs.pop('model_field')
except KeyError:
raise ValueError("ModelField requires 'model_field' kwarg")
self.min_length = kwargs.pop('min_length',
getattr(self.model_field, 'min_length', None))
self.max_length = kwargs.pop('max_length',
getattr(self.model_field, 'max_length', None))
self.min_value = kwargs.pop('min_value',
getattr(self.model_field, 'min_value', None))
self.max_value = kwargs.pop('max_value',
getattr(self.model_field, 'max_value', None))
super(ModelField, self).__init__(*args, **kwargs)
if self.min_length is not None:
self.validators.append(validators.MinLengthValidator(self.min_length))
if self.max_length is not None:
self.validators.append(validators.MaxLengthValidator(self.max_length))
if self.min_value is not None:
self.validators.append(validators.MinValueValidator(self.min_value))
if self.max_value is not None:
self.validators.append(validators.MaxValueValidator(self.max_value))
def get_attribute(self, instance):
return get_attribute(instance, self.source_attrs[:-1])
def to_native(self, data):
rel = getattr(self.model_field, 'rel', None)
if rel is not None:
return rel.to._meta.get_field(rel.field_name).to_python(data)
return self.model_field.to_python(data)
def to_primative(self, obj):
value = self.model_field._get_val_from_obj(obj)
if is_protected_type(value):
return value
return self.model_field.value_to_string(obj)

View File

@ -10,10 +10,10 @@ python primitives.
2. The process of marshalling between python primitives and request and 2. The process of marshalling between python primitives and request and
response content is handled by parsers and renderers. response content is handled by parsers and renderers.
""" """
from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.utils import six from django.utils import six
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty, set_value, Field, SkipField from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html from rest_framework.utils import html
@ -58,13 +58,14 @@ class BaseSerializer(Field):
raise NotImplementedError('`create()` must be implemented.') raise NotImplementedError('`create()` must be implemented.')
def save(self, extras=None): def save(self, extras=None):
attrs = self.validated_data
if extras is not None: if extras is not None:
self.validated_data.update(extras) attrs = dict(list(attrs.items()) + list(extras.items()))
if self.instance is not None: if self.instance is not None:
self.update(self.instance, self._validated_data) self.update(self.instance, attrs)
else: else:
self.instance = self.create(self._validated_data) self.instance = self.create(attrs)
return self.instance return self.instance
@ -74,7 +75,7 @@ class BaseSerializer(Field):
self._validated_data = self.to_native(self._initial_data) self._validated_data = self.to_native(self._initial_data)
except ValidationError as exc: except ValidationError as exc:
self._validated_data = {} self._validated_data = {}
self._errors = exc.detail self._errors = exc.message_dict
else: else:
self._errors = {} self._errors = {}
@ -210,7 +211,7 @@ class Serializer(BaseSerializer):
if validate_method is not None: if validate_method is not None:
validated_value = validate_method(validated_value) validated_value = validate_method(validated_value)
except ValidationError as exc: except ValidationError as exc:
errors[field.field_name] = str(exc) errors[field.field_name] = exc.messages
except SkipField: except SkipField:
pass pass
else: else:
@ -219,8 +220,10 @@ class Serializer(BaseSerializer):
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
# TODO: 'Non field errors' try:
return self.validate(ret) return self.validate(ret)
except ValidationError, exc:
raise ValidationError({'non_field_errors': exc.messages})
def to_primative(self, instance): def to_primative(self, instance):
""" """
@ -539,6 +542,9 @@ class ModelSerializer(Serializer):
if model_field.verbose_name is not None: if model_field.verbose_name is not None:
kwargs['label'] = model_field.verbose_name kwargs['label'] = model_field.verbose_name
if model_field.validators is not None:
kwargs['validators'] = model_field.validators
# if model_field.help_text is not None: # if model_field.help_text is not None:
# kwargs['help_text'] = model_field.help_text # kwargs['help_text'] = model_field.help_text
@ -577,8 +583,7 @@ class ModelSerializer(Serializer):
try: try:
return self.field_mapping[model_field.__class__](**kwargs) return self.field_mapping[model_field.__class__](**kwargs)
except KeyError: except KeyError:
# TODO: Change this to `return ModelField(model_field=model_field, **kwargs)` return ModelField(model_field=model_field, **kwargs)
return CharField(**kwargs)
class HyperlinkedModelSerializerOptions(ModelSerializerOptions): class HyperlinkedModelSerializerOptions(ModelSerializerOptions):

View File

@ -3,7 +3,7 @@ Provides an APIView class that is the base of all views in REST framework.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.exceptions import PermissionDenied from django.core.exceptions import PermissionDenied, ValidationError
from django.http import Http404 from django.http import Http404
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
@ -51,7 +51,8 @@ def exception_handler(exc):
Returns the response that should be used for any given exception. Returns the response that should be used for any given exception.
By default we handle the REST framework `APIException`, and also By default we handle the REST framework `APIException`, and also
Django's builtin `Http404` and `PermissionDenied` exceptions. Django's built-in `ValidationError`, `Http404` and `PermissionDenied`
exceptions.
Any unhandled exceptions may return `None`, which will cause a 500 error Any unhandled exceptions may return `None`, which will cause a 500 error
to be raised. to be raised.
@ -68,6 +69,10 @@ def exception_handler(exc):
status=exc.status_code, status=exc.status_code,
headers=headers) headers=headers)
elif isinstance(exc, ValidationError):
return Response(exc.message_dict,
status=status.HTTP_400_BAD_REQUEST)
elif isinstance(exc, Http404): elif isinstance(exc, Http404):
return Response({'detail': 'Not found'}, return Response({'detail': 'Not found'},
status=status.HTTP_404_NOT_FOUND) status=status.HTTP_404_NOT_FOUND)

View File

@ -0,0 +1,33 @@
# From test_validation...
class TestPreSaveValidationExclusions(TestCase):
def test_pre_save_validation_exclusions(self):
"""
Somewhat weird test case to ensure that we don't perform model
validation on read only fields.
"""
obj = ValidationModel.objects.create(blank_validated_field='')
request = factory.put('/', {}, format='json')
view = UpdateValidationModel().as_view()
response = view(request, pk=obj.pk).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
# From test_permissions...
class ModelPermissionsIntegrationTests(TestCase):
def setUp(...):
...
def test_has_put_as_create_permissions(self):
# User only has update permissions - should be able to update an entity.
request = factory.put('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
# But if PUTing to a new entity, permission should be denied.
request = factory.put('/2', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)

View File

@ -95,19 +95,6 @@ class ModelPermissionsIntegrationTests(TestCase):
response = instance_view(request, pk=1) response = instance_view(request, pk=1)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
def test_has_put_as_create_permissions(self):
# User only has update permissions - should be able to update an entity.
request = factory.put('/1', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='1')
self.assertEqual(response.status_code, status.HTTP_200_OK)
# But if PUTing to a new entity, permission should be denied.
request = factory.put('/2', {'text': 'foobar'}, format='json',
HTTP_AUTHORIZATION=self.updateonly_credentials)
response = instance_view(request, pk='2')
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
# def test_options_permitted(self): # def test_options_permitted(self):
# request = factory.options( # request = factory.options(
# '/', # '/',

View File

@ -225,8 +225,8 @@ class Issue467Tests(TestCase):
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')
class Issue807Tests(TestCase): class Issue807Tests(TestCase):
@ -270,11 +270,11 @@ class Issue807Tests(TestCase):
) )
resp = self.client.get('/html_new_model_viewset/' + param) resp = self.client.get('/html_new_model_viewset/' + param)
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')
def test_form_has_label_and_help_text(self): def test_form_has_label_and_help_text(self):
resp = self.client.get('/html_new_model') resp = self.client.get('/html_new_model')
self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8') self.assertEqual(resp['Content-Type'], 'text/html; charset=utf-8')
self.assertContains(resp, 'Text comes here') # self.assertContains(resp, 'Text comes here')
self.assertContains(resp, 'Text description.') # self.assertContains(resp, 'Text description.')

View File

@ -1,31 +1,31 @@
# from django.test import TestCase from django.test import TestCase
# from django.utils import six from django.utils import six
# from rest_framework.serializers import _resolve_model from rest_framework.serializers import _resolve_model
# from tests.models import BasicModel from tests.models import BasicModel
# class ResolveModelTests(TestCase): class ResolveModelTests(TestCase):
# """ """
# `_resolve_model` should return a Django model class given the `_resolve_model` should return a Django model class given the
# provided argument is a Django model class itself, or a properly provided argument is a Django model class itself, or a properly
# formatted string representation of one. formatted string representation of one.
# """ """
# def test_resolve_django_model(self): def test_resolve_django_model(self):
# resolved_model = _resolve_model(BasicModel) resolved_model = _resolve_model(BasicModel)
# self.assertEqual(resolved_model, BasicModel) self.assertEqual(resolved_model, BasicModel)
# def test_resolve_string_representation(self): def test_resolve_string_representation(self):
# resolved_model = _resolve_model('tests.BasicModel') resolved_model = _resolve_model('tests.BasicModel')
# self.assertEqual(resolved_model, BasicModel) self.assertEqual(resolved_model, BasicModel)
# def test_resolve_unicode_representation(self): def test_resolve_unicode_representation(self):
# resolved_model = _resolve_model(six.text_type('tests.BasicModel')) resolved_model = _resolve_model(six.text_type('tests.BasicModel'))
# self.assertEqual(resolved_model, BasicModel) self.assertEqual(resolved_model, BasicModel)
# def test_resolve_non_django_model(self): def test_resolve_non_django_model(self):
# with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# _resolve_model(TestCase) _resolve_model(TestCase)
# def test_resolve_improper_string_representation(self): def test_resolve_improper_string_representation(self):
# with self.assertRaises(ValueError): with self.assertRaises(ValueError):
# _resolve_model('BasicModel') _resolve_model('BasicModel')

View File

@ -26,19 +26,6 @@ class UpdateValidationModel(generics.RetrieveUpdateDestroyAPIView):
serializer_class = ValidationModelSerializer serializer_class = ValidationModelSerializer
class TestPreSaveValidationExclusions(TestCase):
def test_pre_save_validation_exclusions(self):
"""
Somewhat weird test case to ensure that we don't perform model
validation on read only fields.
"""
obj = ValidationModel.objects.create(blank_validated_field='')
request = factory.put('/', {}, format='json')
view = UpdateValidationModel().as_view()
response = view(request, pk=obj.pk).render()
self.assertEqual(response.status_code, status.HTTP_200_OK)
# Regression for #653 # Regression for #653
class ShouldValidateModel(models.Model): class ShouldValidateModel(models.Model):

View File

@ -1,41 +1,31 @@
# from django.db import models from django.test import TestCase
# from django.test import TestCase from rest_framework import serializers
# from rest_framework import serializers
# class ExampleModel(models.Model): class WriteOnlyFieldTests(TestCase):
# email = models.EmailField(max_length=100) def setUp(self):
# password = models.CharField(max_length=100) class ExampleSerializer(serializers.Serializer):
email = serializers.EmailField()
password = serializers.CharField(write_only=True)
def create(self, attrs):
return attrs
# class WriteOnlyFieldTests(TestCase): self.Serializer = ExampleSerializer
# def test_write_only_fields(self):
# class ExampleSerializer(serializers.Serializer):
# email = serializers.EmailField()
# password = serializers.CharField(write_only=True)
# data = { def write_only_fields_are_present_on_input(self):
# 'email': 'foo@example.com', data = {
# 'password': '123' 'email': 'foo@example.com',
# } 'password': '123'
# serializer = ExampleSerializer(data=data) }
# self.assertTrue(serializer.is_valid()) serializer = self.Serializer(data=data)
# self.assertEquals(serializer.validated_data, data) self.assertTrue(serializer.is_valid())
# self.assertEquals(serializer.data, {'email': 'foo@example.com'}) self.assertEquals(serializer.validated_data, data)
# def test_write_only_fields_meta(self): def write_only_fields_are_not_present_on_output(self):
# class ExampleSerializer(serializers.ModelSerializer): instance = {
# class Meta: 'email': 'foo@example.com',
# model = ExampleModel 'password': '123'
# fields = ('email', 'password') }
# write_only_fields = ('password',) serializer = self.Serializer(instance)
self.assertEquals(serializer.data, {'email': 'foo@example.com'})
# data = {
# 'email': 'foo@example.com',
# 'password': '123'
# }
# serializer = ExampleSerializer(data=data)
# self.assertTrue(serializer.is_valid())
# self.assertTrue(isinstance(serializer.object, ExampleModel))
# self.assertEquals(serializer.validated_data, data)
# self.assertEquals(serializer.data, {'email': 'foo@example.com'})