Workin on

This commit is contained in:
Tom Christie 2014-09-05 16:29:46 +01:00
parent c1036c1753
commit d934824bff
8 changed files with 159 additions and 168 deletions

View File

@ -15,7 +15,7 @@ class APIException(Exception):
Subclasses should provide `.status_code` and `.default_detail` properties. Subclasses should provide `.status_code` and `.default_detail` properties.
""" """
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = '' default_detail = 'A server error occured'
def __init__(self, detail=None): def __init__(self, detail=None):
self.detail = detail or self.default_detail self.detail = detail or self.default_detail
@ -29,6 +29,11 @@ class ParseError(APIException):
default_detail = 'Malformed request.' default_detail = 'Malformed request.'
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = 'Invalid data in request.'
class AuthenticationFailed(APIException): class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED status_code = status.HTTP_401_UNAUTHORIZED
default_detail = 'Incorrect authentication credentials.' default_detail = 'Incorrect authentication credentials.'
@ -54,7 +59,7 @@ class MethodNotAllowed(APIException):
class NotAcceptable(APIException): class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE status_code = status.HTTP_406_NOT_ACCEPTABLE
default_detail = "Could not satisfy the request's Accept header" default_detail = "Could not satisfy the request Accept header"
def __init__(self, detail=None, available_renderers=None): def __init__(self, detail=None, available_renderers=None):
self.detail = detail or self.default_detail self.detail = detail or self.default_detail

View File

@ -1,3 +1,4 @@
from rest_framework.exceptions import ValidationError
from rest_framework.utils import html from rest_framework.utils import html
import inspect import inspect
@ -59,10 +60,6 @@ def set_value(dictionary, keys, value):
dictionary[keys[-1]] = value dictionary[keys[-1]] = value
class ValidationError(Exception):
pass
class SkipField(Exception): class SkipField(Exception):
pass pass
@ -204,6 +201,22 @@ class Field(object):
msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) msg = self._MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key)
raise AssertionError(msg) raise AssertionError(msg)
def __new__(cls, *args, **kwargs):
instance = super(Field, cls).__new__(cls)
instance._args = args
instance._kwargs = kwargs
return instance
def __repr__(self):
arg_string = ', '.join([repr(val) for val in self._args])
kwarg_string = ', '.join([
'%s=%s' % (key, repr(val)) for key, val in self._kwargs.items()
])
if arg_string and kwarg_string:
arg_string += ', '
class_name = self.__class__.__name__
return "%s(%s%s)" % (class_name, arg_string, kwarg_string)
class BooleanField(Field): class BooleanField(Field):
MESSAGES = { MESSAGES = {
@ -308,6 +321,11 @@ class IntegerField(Field):
'invalid_integer': 'A valid integer is required.' 'invalid_integer': 'A valid integer is required.'
} }
def __init__(self, **kwargs):
self.max_value = kwargs.pop('max_value')
self.min_value = kwargs.pop('min_value')
super(CharField, self).__init__(**kwargs)
def to_native(self, data): def to_native(self, data):
try: try:
data = int(str(data)) data = int(str(data))

View File

@ -27,7 +27,7 @@ def strict_positive_int(integer_string, cutoff=None):
def get_object_or_404(queryset, *filter_args, **filter_kwargs): def get_object_or_404(queryset, *filter_args, **filter_kwargs):
""" """
Same as Django's standard shortcut, but make sure to raise 404 Same as Django's standard shortcut, but make sure to also raise 404
if the filter_kwargs don't match the required types. if the filter_kwargs don't match the required types.
""" """
try: try:
@ -249,34 +249,6 @@ class GenericAPIView(views.APIView):
# #
# The are not called by GenericAPIView directly, # The are not called by GenericAPIView directly,
# but are used by the mixin methods. # but are used by the mixin methods.
def pre_save(self, obj):
"""
Placeholder method for calling before saving an object.
May be used to set attributes on the object that are implicit
in either the request, or the url.
"""
pass
def post_save(self, obj, created=False):
"""
Placeholder method for calling after saving an object.
"""
pass
def pre_delete(self, obj):
"""
Placeholder method for calling before deleting an object.
"""
pass
def post_delete(self, obj):
"""
Placeholder method for calling after deleting an object.
"""
pass
def metadata(self, request): def metadata(self, request):
""" """
Return a dictionary of metadata about the view. Return a dictionary of metadata about the view.

View File

@ -19,14 +19,10 @@ class CreateModelMixin(object):
""" """
def create(self, request, *args, **kwargs): def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.DATA) serializer = self.get_serializer(data=request.DATA)
serializer.is_valid(raise_exception=True)
if serializer.is_valid(): serializer.save()
self.object = serializer.save()
headers = self.get_success_headers(serializer.data) headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
headers=headers)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_success_headers(self, data): def get_success_headers(self, data):
try: try:
@ -40,15 +36,12 @@ class ListModelMixin(object):
List a queryset. List a queryset.
""" """
def list(self, request, *args, **kwargs): def list(self, request, *args, **kwargs):
self.object_list = self.filter_queryset(self.get_queryset()) instance = self.filter_queryset(self.get_queryset())
page = self.paginate_queryset(instance)
# Switch between paginated or standard style responses
page = self.paginate_queryset(self.object_list)
if page is not None: if page is not None:
serializer = self.get_pagination_serializer(page) serializer = self.get_pagination_serializer(page)
else: else:
serializer = self.get_serializer(self.object_list, many=True) serializer = self.get_serializer(instance, many=True)
return Response(serializer.data) return Response(serializer.data)
@ -57,8 +50,8 @@ class RetrieveModelMixin(object):
Retrieve a model instance. Retrieve a model instance.
""" """
def retrieve(self, request, *args, **kwargs): def retrieve(self, request, *args, **kwargs):
self.object = self.get_object() instance = self.get_object()
serializer = self.get_serializer(self.object) serializer = self.get_serializer(instance)
return Response(serializer.data) return Response(serializer.data)
@ -68,22 +61,52 @@ class UpdateModelMixin(object):
""" """
def update(self, request, *args, **kwargs): def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False) partial = kwargs.pop('partial', False)
self.object = self.get_object_or_none() instance = self.get_object()
serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data)
serializer = self.get_serializer(self.object, data=request.DATA, partial=partial) def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
return self.update(request, *args, **kwargs)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
if self.object is None: class DestroyModelMixin(object):
"""
Destroy a model instance.
"""
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
instance.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
# The AllowPUTAsCreateMixin was previously the default behaviour
# for PUT requests. This has now been removed and must be *explictly*
# included if it is the behavior that you want.
# For more info see: ...
class AllowPUTAsCreateMixin(object):
"""
The following mixin class may be used in order to support PUT-as-create
behavior for incoming requests.
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
instance = self.get_object_or_none()
serializer = self.get_serializer(instance, data=request.DATA, partial=partial)
serializer.is_valid(raise_exception=True)
if instance is None:
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
lookup_value = self.kwargs[lookup_url_kwarg] lookup_value = self.kwargs[lookup_url_kwarg]
extras = {self.lookup_field: lookup_value} extras = {self.lookup_field: lookup_value}
self.object = serializer.save(extras=extras) serializer.save(extras=extras)
return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.data, status=status.HTTP_201_CREATED)
self.object = serializer.save() serializer.save()
return Response(serializer.data, status=status.HTTP_200_OK) return Response(serializer.data)
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True kwargs['partial'] = True
@ -103,15 +126,3 @@ class UpdateModelMixin(object):
# PATCH requests where the object does not exist should still # PATCH requests where the object does not exist should still
# return a 404 response. # return a 404 response.
raise raise
class DestroyModelMixin(object):
"""
Destroy a model instance.
"""
def destroy(self, request, *args, **kwargs):
obj = self.get_object()
self.pre_delete(obj)
obj.delete()
self.post_delete(obj)
return Response(status=status.HTTP_204_NO_CONTENT)

View File

@ -13,7 +13,8 @@ response content is handled by parsers and renderers.
from django.db import models from django.db import models
from django.utils import six from django.utils import six
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
from rest_framework.fields import empty, set_value, Field, SkipField, ValidationError from rest_framework.exceptions import ValidationError
from rest_framework.fields import empty, set_value, Field, SkipField
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import html from rest_framework.utils import html
import copy import copy
@ -34,43 +35,53 @@ FieldResult = namedtuple('FieldResult', ['field', 'value', 'error'])
class BaseSerializer(Field): class BaseSerializer(Field):
"""
The BaseSerializer class provides a minimal class which may be used
for writing custom serializer implementations.
"""
def __init__(self, instance=None, data=None, **kwargs): def __init__(self, instance=None, data=None, **kwargs):
super(BaseSerializer, self).__init__(**kwargs) super(BaseSerializer, self).__init__(**kwargs)
self.instance = instance self.instance = instance
self._initial_data = data self._initial_data = data
def to_native(self, data): def to_native(self, data):
raise NotImplementedError() raise NotImplementedError('`to_native()` must be implemented.')
def to_primative(self, instance): def to_primative(self, instance):
raise NotImplementedError() raise NotImplementedError('`to_primative()` must be implemented.')
def update(self, instance): def update(self, instance, attrs):
raise NotImplementedError() raise NotImplementedError('`update()` must be implemented.')
def create(self): def create(self, attrs):
raise NotImplementedError() raise NotImplementedError('`create()` must be implemented.')
def save(self, extras=None): def save(self, extras=None):
if extras is not None: if extras is not None:
self._validated_data.update(extras) self.validated_data.update(extras)
if self.instance is not None: if self.instance is not None:
self.update(self.instance) self.update(self.instance, self._validated_data)
else: else:
self.instance = self.create() self.instance = self.create(self._validated_data)
return self.instance return self.instance
def is_valid(self): def is_valid(self, raise_exception=False):
if not hasattr(self, '_validated_data'):
try: try:
self._validated_data = self.to_native(self._initial_data) self._validated_data = self.to_native(self._initial_data)
except ValidationError as exc: except ValidationError as exc:
self._validated_data = {} self._validated_data = {}
self._errors = exc.args[0] self._errors = exc.detail
return False else:
self._errors = {} self._errors = {}
return True
if self._errors and raise_exception:
raise ValidationError(self._errors)
return not bool(self._errors)
@property @property
def data(self): def data(self):
@ -184,14 +195,20 @@ class Serializer(BaseSerializer):
""" """
Dict of native values <- Dict of primitive datatypes. Dict of native values <- Dict of primitive datatypes.
""" """
if not isinstance(data, dict):
raise ValidationError({'non_field_errors': ['Invalid data']})
ret = {} ret = {}
errors = {} errors = {}
fields = [field for field in self.fields.values() if not field.read_only] fields = [field for field in self.fields.values() if not field.read_only]
for field in fields: for field in fields:
validate_method = getattr(self, 'validate_' + field.field_name, None)
primitive_value = field.get_value(data) primitive_value = field.get_value(data)
try: try:
validated_value = field.validate(primitive_value) validated_value = field.validate(primitive_value)
if validate_method is not None:
validated_value = validate_method(validated_value)
except ValidationError as exc: except ValidationError as exc:
errors[field.field_name] = str(exc) errors[field.field_name] = str(exc)
except SkipField: except SkipField:
@ -202,6 +219,7 @@ class Serializer(BaseSerializer):
if errors: if errors:
raise ValidationError(errors) raise ValidationError(errors)
# TODO: 'Non field errors'
return self.validate(ret) return self.validate(ret)
def to_primative(self, instance): def to_primative(self, instance):
@ -340,12 +358,12 @@ class ModelSerializer(Serializer):
self.opts = self._options_class(self.Meta) self.opts = self._options_class(self.Meta)
super(ModelSerializer, self).__init__(*args, **kwargs) super(ModelSerializer, self).__init__(*args, **kwargs)
def create(self): def create(self, attrs):
ModelClass = self.opts.model ModelClass = self.opts.model
return ModelClass.objects.create(**self.validated_data) return ModelClass.objects.create(**attrs)
def update(self, obj): def update(self, obj, attrs):
for attr, value in self.validated_data.items(): for attr, value in attrs.items():
setattr(obj, attr, value) setattr(obj, attr, value)
obj.save() obj.save()

View File

@ -360,18 +360,15 @@ class TestInstanceView(TestCase):
def test_put_to_deleted_instance(self): def test_put_to_deleted_instance(self):
""" """
PUT requests to RetrieveUpdateDestroyAPIView should create an object PUT requests to RetrieveUpdateDestroyAPIView should return 404 if
if it does not currently exist. an object does not currently exist.
""" """
self.objects.get(id=1).delete() self.objects.get(id=1).delete()
data = {'text': 'foobar'} data = {'text': 'foobar'}
request = factory.put('/1', data, format='json') request = factory.put('/1', data, format='json')
with self.assertNumQueries(2): with self.assertNumQueries(1):
response = self.view(request, pk=1).render() response = self.view(request, pk=1).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.data, {'id': 1, 'text': 'foobar'})
updated = self.objects.get(id=1)
self.assertEqual(updated.text, 'foobar')
def test_put_to_filtered_out_instance(self): def test_put_to_filtered_out_instance(self):
""" """
@ -382,35 +379,7 @@ class TestInstanceView(TestCase):
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json') request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
response = self.view(request, pk=filtered_out_pk).render() response = self.view(request, pk=filtered_out_pk).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
def test_put_as_create_on_id_based_url(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView should create an object
at the requested url if it doesn't exist.
"""
data = {'text': 'foobar'}
# pk fields can not be created on demand, only the database can set the pk for a new object
request = factory.put('/5', data, format='json')
with self.assertNumQueries(2):
response = self.view(request, pk=5).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
new_obj = self.objects.get(pk=5)
self.assertEqual(new_obj.text, 'foobar')
def test_put_as_create_on_slug_based_url(self):
"""
PUT requests to RetrieveUpdateDestroyAPIView should create an object
at the requested url if possible, else return HTTP_403_FORBIDDEN error-response.
"""
data = {'text': 'foobar'}
request = factory.put('/test_slug', data, format='json')
with self.assertNumQueries(2):
response = self.slug_based_view(request, slug='test_slug').render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(response.data, {'slug': 'test_slug', 'text': 'foobar'})
new_obj = SlugBasedModel.objects.get(slug='test_slug')
self.assertEqual(new_obj.text, 'foobar')
def test_patch_cannot_create_an_object(self): def test_patch_cannot_create_an_object(self):
""" """

View File

@ -48,11 +48,10 @@ class ShouldValidateModel(models.Model):
class ShouldValidateModelSerializer(serializers.ModelSerializer): class ShouldValidateModelSerializer(serializers.ModelSerializer):
renamed = serializers.CharField(source='should_validate_field', required=False) renamed = serializers.CharField(source='should_validate_field', required=False)
def validate_renamed(self, attrs, source): def validate_renamed(self, value):
value = attrs[source]
if len(value) < 3: if len(value) < 3:
raise serializers.ValidationError('Minimum 3 characters.') raise serializers.ValidationError('Minimum 3 characters.')
return attrs return value
class Meta: class Meta:
model = ShouldValidateModel model = ShouldValidateModel

View File

@ -1,42 +1,41 @@
from django.db import models # from django.db import models
from django.test import TestCase # from django.test import TestCase
from rest_framework import serializers # from rest_framework import serializers
class ExampleModel(models.Model): # class ExampleModel(models.Model):
email = models.EmailField(max_length=100) # email = models.EmailField(max_length=100)
password = models.CharField(max_length=100) # password = models.CharField(max_length=100)
class WriteOnlyFieldTests(TestCase): # class WriteOnlyFieldTests(TestCase):
def test_write_only_fields(self): # def test_write_only_fields(self):
class ExampleSerializer(serializers.Serializer): # class ExampleSerializer(serializers.Serializer):
email = serializers.EmailField() # email = serializers.EmailField()
password = serializers.CharField(write_only=True) # password = serializers.CharField(write_only=True)
data = { # data = {
'email': 'foo@example.com', # 'email': 'foo@example.com',
'password': '123' # 'password': '123'
} # }
serializer = ExampleSerializer(data=data) # serializer = ExampleSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertEquals(serializer.object, data) # self.assertEquals(serializer.validated_data, data)
self.assertEquals(serializer.data, {'email': 'foo@example.com'}) # self.assertEquals(serializer.data, {'email': 'foo@example.com'})
def test_write_only_fields_meta(self): # def test_write_only_fields_meta(self):
class ExampleSerializer(serializers.ModelSerializer): # class ExampleSerializer(serializers.ModelSerializer):
class Meta: # class Meta:
model = ExampleModel # model = ExampleModel
fields = ('email', 'password') # fields = ('email', 'password')
write_only_fields = ('password',) # write_only_fields = ('password',)
data = { # data = {
'email': 'foo@example.com', # 'email': 'foo@example.com',
'password': '123' # 'password': '123'
} # }
serializer = ExampleSerializer(data=data) # serializer = ExampleSerializer(data=data)
self.assertTrue(serializer.is_valid()) # self.assertTrue(serializer.is_valid())
self.assertTrue(isinstance(serializer.object, ExampleModel)) # self.assertTrue(isinstance(serializer.object, ExampleModel))
self.assertEquals(serializer.object.email, data['email']) # self.assertEquals(serializer.validated_data, data)
self.assertEquals(serializer.object.password, data['password']) # self.assertEquals(serializer.data, {'email': 'foo@example.com'})
self.assertEquals(serializer.data, {'email': 'foo@example.com'})