This commit is contained in:
Ludwig Kraatz 2012-12-07 01:53:55 -08:00
commit 986b311a08
7 changed files with 65 additions and 28 deletions

View File

@ -673,7 +673,7 @@ class HyperlinkedIdentityField(Field):
except:
pass
raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
##### Typed Fields #####

View File

@ -18,16 +18,12 @@ class CreateModelMixin(object):
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
headers = self.get_success_headers(serializer.data)
self.object = serializer.save(force_insert=True)
headers = self.get_response_headers(request, status.HTTP_201_CREATED, serializer=serializer)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get_success_headers(self, data):
try:
return {'Location': data['url']}
except (TypeError, KeyError):
return {}
headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers)
def pre_save(self, obj):
pass
@ -62,7 +58,8 @@ class ListModelMixin(object):
else:
serializer = self.get_serializer(self.object_list)
return Response(serializer.data)
headers = self.get_response_headers(request, serializer=serializer)
return Response(serializer.data, headers=headers)
class RetrieveModelMixin(object):
@ -73,7 +70,8 @@ class RetrieveModelMixin(object):
def retrieve(self, request, *args, **kwargs):
self.object = self.get_object()
serializer = self.get_serializer(self.object)
return Response(serializer.data)
headers = self.get_response_headers(request, serializer=serializer)
return Response(serializer.data, headers=headers)
class UpdateModelMixin(object):
@ -93,11 +91,16 @@ class UpdateModelMixin(object):
if serializer.is_valid():
self.pre_save(serializer.object)
self.object = serializer.save()
if created:
self.object = serializer.save(force_insert=True)
else:
self.object = serializer.save(force_update=True)
status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
return Response(serializer.data, status=status_code)
headers = self.get_response_headers(request, status_code, serializer=serializer)
return Response(serializer.data, status=status_code, headers=headers)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers)
def pre_save(self, obj):
"""
@ -122,4 +125,5 @@ class DestroyModelMixin(object):
def destroy(self, request, *args, **kwargs):
self.object = self.get_object()
self.object.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
headers = self.get_response_headers(request, status.HTTP_204_NO_CONTENT, object=self.object)
return Response(status=status.HTTP_204_NO_CONTENT, headers=headers)

View File

@ -10,7 +10,7 @@ import copy
import string
from django import forms
from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template
from django.template import loader, Template
from django.utils import simplejson as json
from rest_framework.compat import yaml
from rest_framework.exceptions import ConfigurationError
@ -28,7 +28,9 @@ class BaseRenderer(object):
All renderers should extend this class, setting the `media_type`
and `format` attributes, and override the `.render()` method.
"""
context_class = api_settings.DEFAULT_CONTEXT_CLASS
media_type = None
format = None
@ -197,7 +199,7 @@ class TemplateHTMLRenderer(BaseRenderer):
def resolve_context(self, data, request, response):
if response.exception:
data['status_code'] = response.status_code
return RequestContext(request, data)
return self.context_class(request, data)
def get_template_names(self, response, view):
if response.template_name:
@ -436,7 +438,7 @@ class BrowsableAPIRenderer(BaseRenderer):
breadcrumb_list = get_breadcrumbs(request.path)
template = loader.get_template(self.template)
context = RequestContext(request, {
context = self.context_class(request, {
'content': content,
'view': view,
'request': request,

View File

@ -127,7 +127,7 @@ class BaseSerializer(Field):
"""
return {}
def get_fields(self):
def get_all_fields(self):
"""
Returns the complete set of fields for the object as a dict.
@ -149,6 +149,15 @@ class BaseSerializer(Field):
if key not in ret:
ret[key] = val
return ret
def get_fields(self):
"""
Returns a subset of fields specified by exclude and fields attr.
of the Option Class.
"""
ret = self.get_all_fields()
# If 'fields' is specified, use those fields, in that order.
if self.opts.fields:
new = SortedDict()
@ -321,11 +330,11 @@ class BaseSerializer(Field):
self._data = self.to_native(self.object)
return self._data
def save(self):
def save(self, **kwargs):
"""
Save the deserialized object and return it.
"""
self.object.save()
"""
self.object.save(**kwargs)
return self.object
@ -491,11 +500,11 @@ class ModelSerializer(Serializer):
self.m2m_data[field.name] = attrs.pop(field.name)
return self.opts.model(**attrs)
def save(self, save_m2m=True):
def save(self, save_m2m=True, **kwargs):
"""
Save the deserialized object and return it.
"""
self.object.save()
super(ModelSerializer, self).save(**kwargs)
if getattr(self, 'm2m_data', None) and save_m2m:
for accessor_name, object_list in self.m2m_data.items():

View File

@ -54,6 +54,9 @@ DEFAULTS = {
'user': None,
'anon': None,
},
'DEFAULT_CONTEXT_CLASS':
'django.template.RequestContext',
# Pagination
'PAGINATE_BY': None,
@ -84,6 +87,7 @@ IMPORT_STRINGS = (
'DEFAULT_AUTHENTICATION_CLASSES',
'DEFAULT_PERMISSION_CLASSES',
'DEFAULT_THROTTLE_CLASSES',
'DEFAULT_CONTEXT_CLASS',
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_MODEL_SERIALIZER_CLASS',
'DEFAULT_PAGINATION_SERIALIZER_CLASS',

View File

@ -221,7 +221,6 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
request = factory.post('/photos/', data=data)
response = self.list_create_view(request).render()
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
self.assertEqual(self.post.photo_set.count(), 1)
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')

View File

@ -13,6 +13,7 @@ from rest_framework.compat import View, apply_markdown
from rest_framework.response import Response
from rest_framework.request import Request
from rest_framework.settings import api_settings
from rest_framework import fields
def _remove_trailing_string(content, trailing):
@ -291,6 +292,23 @@ class APIView(View):
# Perform content negotiation and store the accepted info on the request
neg = self.perform_content_negotiation(request)
request.accepted_renderer, request.accepted_media_type = neg
def get_response_headers(self, request, status_code=status.HTTP_200_OK, serializer=None, object=None):
headers = {}
obj = object or (serializer and serializer.object)
serializer_fields = serializer and serializer.get_all_fields()
if status_code == status.HTTP_201_CREATED:
if obj and hasattr(obj, 'get_absolute_url'):
headers['Location'] = obj.get_absolute_url()
elif obj and serializer_fields:
for field_name, field in serializer_fields.iteritems():
if isinstance(field, fields.HyperlinkedIdentityField):
headers['Location'] = field.field_to_native(obj, field_name)
return headers
def finalize_response(self, request, response, *args, **kwargs):
"""
@ -371,4 +389,5 @@ class APIView(View):
We may as well implement this as Django will otherwise provide
a less useful default implementation.
"""
return Response(self.metadata(request), status=status.HTTP_200_OK)
headers = self.get_response_headers(request)
return Response(self.metadata(request), status=status.HTTP_200_OK, headers=headers)