diff --git a/rest_framework/fields.py b/rest_framework/fields.py index c28a9695a..72d3277b1 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -673,7 +673,7 @@ class HyperlinkedIdentityField(Field): except: pass - raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) + raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) ##### Typed Fields ##### diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 1edcfa5c9..0b1d53ea8 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -18,16 +18,12 @@ class CreateModelMixin(object): serializer = self.get_serializer(data=request.DATA, files=request.FILES) if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() - headers = self.get_success_headers(serializer.data) + self.object = serializer.save(force_insert=True) + headers = self.get_response_headers(request, status.HTTP_201_CREATED, serializer=serializer) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - def get_success_headers(self, data): - try: - return {'Location': data['url']} - except (TypeError, KeyError): - return {} + + headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers) def pre_save(self, obj): pass @@ -62,7 +58,8 @@ class ListModelMixin(object): else: serializer = self.get_serializer(self.object_list) - return Response(serializer.data) + headers = self.get_response_headers(request, serializer=serializer) + return Response(serializer.data, headers=headers) class RetrieveModelMixin(object): @@ -73,7 +70,8 @@ class RetrieveModelMixin(object): def retrieve(self, request, *args, **kwargs): self.object = self.get_object() serializer = self.get_serializer(self.object) - return Response(serializer.data) + headers = self.get_response_headers(request, serializer=serializer) + return Response(serializer.data, headers=headers) class UpdateModelMixin(object): @@ -93,11 +91,16 @@ class UpdateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + if created: + self.object = serializer.save(force_insert=True) + else: + self.object = serializer.save(force_update=True) status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK - return Response(serializer.data, status=status_code) + headers = self.get_response_headers(request, status_code, serializer=serializer) + return Response(serializer.data, status=status_code, headers=headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers) def pre_save(self, obj): """ @@ -122,4 +125,5 @@ class DestroyModelMixin(object): def destroy(self, request, *args, **kwargs): self.object = self.get_object() self.object.delete() - return Response(status=status.HTTP_204_NO_CONTENT) + headers = self.get_response_headers(request, status.HTTP_204_NO_CONTENT, object=self.object) + return Response(status=status.HTTP_204_NO_CONTENT, headers=headers) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 1220bca10..8d5fcf5a7 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -10,7 +10,7 @@ import copy import string from django import forms from django.http.multipartparser import parse_header -from django.template import RequestContext, loader, Template +from django.template import loader, Template from django.utils import simplejson as json from rest_framework.compat import yaml from rest_framework.exceptions import ConfigurationError @@ -28,7 +28,9 @@ class BaseRenderer(object): All renderers should extend this class, setting the `media_type` and `format` attributes, and override the `.render()` method. """ - + + context_class = api_settings.DEFAULT_CONTEXT_CLASS + media_type = None format = None @@ -197,7 +199,7 @@ class TemplateHTMLRenderer(BaseRenderer): def resolve_context(self, data, request, response): if response.exception: data['status_code'] = response.status_code - return RequestContext(request, data) + return self.context_class(request, data) def get_template_names(self, response, view): if response.template_name: @@ -436,7 +438,7 @@ class BrowsableAPIRenderer(BaseRenderer): breadcrumb_list = get_breadcrumbs(request.path) template = loader.get_template(self.template) - context = RequestContext(request, { + context = self.context_class(request, { 'content': content, 'view': view, 'request': request, diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 67eafdf03..7c59418b2 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -127,7 +127,7 @@ class BaseSerializer(Field): """ return {} - def get_fields(self): + def get_all_fields(self): """ Returns the complete set of fields for the object as a dict. @@ -149,6 +149,15 @@ class BaseSerializer(Field): if key not in ret: ret[key] = val + return ret + + def get_fields(self): + """ + Returns a subset of fields specified by exclude and fields attr. + of the Option Class. + """ + ret = self.get_all_fields() + # If 'fields' is specified, use those fields, in that order. if self.opts.fields: new = SortedDict() @@ -321,11 +330,11 @@ class BaseSerializer(Field): self._data = self.to_native(self.object) return self._data - def save(self): + def save(self, **kwargs): """ Save the deserialized object and return it. - """ - self.object.save() + """ + self.object.save(**kwargs) return self.object @@ -491,11 +500,11 @@ class ModelSerializer(Serializer): self.m2m_data[field.name] = attrs.pop(field.name) return self.opts.model(**attrs) - def save(self, save_m2m=True): + def save(self, save_m2m=True, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + super(ModelSerializer, self).save(**kwargs) if getattr(self, 'm2m_data', None) and save_m2m: for accessor_name, object_list in self.m2m_data.items(): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index ee24a4ad9..5de19d965 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -54,6 +54,9 @@ DEFAULTS = { 'user': None, 'anon': None, }, + + 'DEFAULT_CONTEXT_CLASS': + 'django.template.RequestContext', # Pagination 'PAGINATE_BY': None, @@ -84,6 +87,7 @@ IMPORT_STRINGS = ( 'DEFAULT_AUTHENTICATION_CLASSES', 'DEFAULT_PERMISSION_CLASSES', 'DEFAULT_THROTTLE_CLASSES', + 'DEFAULT_CONTEXT_CLASS', 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index d7effce70..9be65992d 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -221,7 +221,6 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): request = factory.post('/photos/', data=data) response = self.list_create_view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') diff --git a/rest_framework/views.py b/rest_framework/views.py index 10bdd5a53..604d95bd3 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -13,6 +13,7 @@ from rest_framework.compat import View, apply_markdown from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings +from rest_framework import fields def _remove_trailing_string(content, trailing): @@ -291,6 +292,23 @@ class APIView(View): # Perform content negotiation and store the accepted info on the request neg = self.perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg + + def get_response_headers(self, request, status_code=status.HTTP_200_OK, serializer=None, object=None): + headers = {} + + obj = object or (serializer and serializer.object) + serializer_fields = serializer and serializer.get_all_fields() + + if status_code == status.HTTP_201_CREATED: + + if obj and hasattr(obj, 'get_absolute_url'): + headers['Location'] = obj.get_absolute_url() + elif obj and serializer_fields: + for field_name, field in serializer_fields.iteritems(): + if isinstance(field, fields.HyperlinkedIdentityField): + headers['Location'] = field.field_to_native(obj, field_name) + + return headers def finalize_response(self, request, response, *args, **kwargs): """ @@ -371,4 +389,5 @@ class APIView(View): We may as well implement this as Django will otherwise provide a less useful default implementation. """ - return Response(self.metadata(request), status=status.HTTP_200_OK) + headers = self.get_response_headers(request) + return Response(self.metadata(request), status=status.HTTP_200_OK, headers=headers)