diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 23e241c03..0b1d53ea8 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -19,8 +19,11 @@ class CreateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) self.object = serializer.save(force_insert=True) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=serializer.headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers) + headers = self.get_response_headers(request, status.HTTP_201_CREATED, serializer=serializer) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + + headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers) def pre_save(self, obj): pass @@ -55,7 +58,8 @@ class ListModelMixin(object): else: serializer = self.get_serializer(self.object_list) - return Response(serializer.data, headers=serializer.headers) + headers = self.get_response_headers(request, serializer=serializer) + return Response(serializer.data, headers=headers) class RetrieveModelMixin(object): @@ -66,7 +70,8 @@ class RetrieveModelMixin(object): def retrieve(self, request, *args, **kwargs): self.object = self.get_object() serializer = self.get_serializer(self.object) - return Response(serializer.data, headers=serializer.headers) + headers = self.get_response_headers(request, serializer=serializer) + return Response(serializer.data, headers=headers) class UpdateModelMixin(object): @@ -91,9 +96,11 @@ class UpdateModelMixin(object): else: self.object = serializer.save(force_update=True) status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK - return Response(serializer.data, status=status_code, headers=serializer.headers) + headers = self.get_response_headers(request, status_code, serializer=serializer) + return Response(serializer.data, status=status_code, headers=headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers) + headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers) def pre_save(self, obj): """ @@ -118,4 +125,5 @@ class DestroyModelMixin(object): def destroy(self, request, *args, **kwargs): self.object = self.get_object() self.object.delete() - return Response(status=status.HTTP_204_NO_CONTENT) + headers = self.get_response_headers(request, status.HTTP_204_NO_CONTENT, object=self.object) + return Response(status=status.HTTP_204_NO_CONTENT, headers=headers) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 61bb54d3a..1bba5c9ed 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -107,8 +107,7 @@ class BaseSerializer(Field): self._data = None self._files = None - self._errors = None - self._headers = {} + self._errors = None ##### # Methods to determine which fields to use when (de)serializing objects. @@ -119,7 +118,7 @@ class BaseSerializer(Field): """ return {} - def get_fields(self): + def get_all_fields(self): """ Returns the complete set of fields for the object as a dict. @@ -141,6 +140,15 @@ class BaseSerializer(Field): if key not in ret: ret[key] = val + return ret + + def get_fields(self): + """ + Returns a subset of fields specified by exclude and fields attr. + of the Option Class. + """ + ret = self.get_all_fields() + # If 'fields' is specified, use those fields, in that order. if self.opts.fields: new = SortedDict() @@ -308,44 +316,9 @@ class BaseSerializer(Field): def save(self, **kwargs): """ Save the deserialized object and return it. - """ - pk_val = self.object._get_pk_val(self.object.__class__._meta) - pk_set = pk_val is not None - - if ((pk_set) and - ((('force_update' in kwargs) or ('update_fields' in kwargs)) or - ('force_insert' not in kwargs and self.object.__class__.objects.filter(pk=pk_val).exists()))): - created = False - else: - created = True - - self.object.save(**kwargs) - - if created: - self.set_location_header() - + """ + self.object.save(**kwargs) return self.object - - def _generate_headers(self): - return {} - - @property - def headers(self): - ret = self._generate_headers() - ret.update(self._headers) - return ret - - def set_location_header(self): - if hasattr(self.object, 'get_absolute_url'): - self._headers['Location'] = self.object.get_absolute_url() - return True - else: - for field_name, field in self.fields.iteritems(): - if isinstance(field, HyperlinkedIdentityField): - self._headers['Location'] = field.field_to_native(self.object, field_name) - return True - - return False class Serializer(BaseSerializer): diff --git a/rest_framework/views.py b/rest_framework/views.py index 10bdd5a53..604d95bd3 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -13,6 +13,7 @@ from rest_framework.compat import View, apply_markdown from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings +from rest_framework import fields def _remove_trailing_string(content, trailing): @@ -291,6 +292,23 @@ class APIView(View): # Perform content negotiation and store the accepted info on the request neg = self.perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg + + def get_response_headers(self, request, status_code=status.HTTP_200_OK, serializer=None, object=None): + headers = {} + + obj = object or (serializer and serializer.object) + serializer_fields = serializer and serializer.get_all_fields() + + if status_code == status.HTTP_201_CREATED: + + if obj and hasattr(obj, 'get_absolute_url'): + headers['Location'] = obj.get_absolute_url() + elif obj and serializer_fields: + for field_name, field in serializer_fields.iteritems(): + if isinstance(field, fields.HyperlinkedIdentityField): + headers['Location'] = field.field_to_native(obj, field_name) + + return headers def finalize_response(self, request, response, *args, **kwargs): """ @@ -371,4 +389,5 @@ class APIView(View): We may as well implement this as Django will otherwise provide a less useful default implementation. """ - return Response(self.metadata(request), status=status.HTTP_200_OK) + headers = self.get_response_headers(request) + return Response(self.metadata(request), status=status.HTTP_200_OK, headers=headers)