diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 1edcfa5c9..23e241c03 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -18,16 +18,9 @@ class CreateModelMixin(object): serializer = self.get_serializer(data=request.DATA, files=request.FILES) if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() - headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - def get_success_headers(self, data): - try: - return {'Location': data['url']} - except (TypeError, KeyError): - return {} + self.object = serializer.save(force_insert=True) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=serializer.headers) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers) def pre_save(self, obj): pass @@ -62,7 +55,7 @@ class ListModelMixin(object): else: serializer = self.get_serializer(self.object_list) - return Response(serializer.data) + return Response(serializer.data, headers=serializer.headers) class RetrieveModelMixin(object): @@ -73,7 +66,7 @@ class RetrieveModelMixin(object): def retrieve(self, request, *args, **kwargs): self.object = self.get_object() serializer = self.get_serializer(self.object) - return Response(serializer.data) + return Response(serializer.data, headers=serializer.headers) class UpdateModelMixin(object): @@ -93,11 +86,14 @@ class UpdateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + if created: + self.object = serializer.save(force_insert=True) + else: + self.object = serializer.save(force_update=True) status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK - return Response(serializer.data, status=status_code) + return Response(serializer.data, status=status_code, headers=serializer.headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers) def pre_save(self, obj): """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4519ab053..672d65c6c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -108,6 +108,8 @@ class BaseSerializer(Field): self._data = None self._files = None self._errors = None + self._headers = {} + ##### # Methods to determine which fields to use when (de)serializing objects. @@ -304,12 +306,47 @@ class BaseSerializer(Field): self._data = self.to_native(self.object) return self._data - def save(self): + def save(self, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + pk_val = self.object._get_pk_val(self.object.__class__._meta) + pk_set = pk_val is not None + + if ((pk_set) and + ((('force_update' in kwargs) or ('update_fields' in kwargs)) or + ('force_insert' not in kwargs and self.object.__class__.objects.filter(pk=pk_val).exists()))): + created = False + else: + created = True + + self.object.save(**kwargs) + + if created: + self.set_location_header() + return self.object + + def generate_header(self): + return {} + + @property + def headers(self): + #self._headers.update(self.generate_header()) + return self._headers + + def set_location_header(self): + self._headers['Location'] = 'x' + if hasattr(self.object, 'get_absolute_url'): + self._headers['Location'] = self.object.get_absolute_url() + return True + else: + for field_name, field in self.fields.iteritems(): + if isinstance(field, HyperlinkedIdentityField): + self._headers['Location'] = field.field_to_native(self.object, field_name) + return True + + return False class Serializer(BaseSerializer): @@ -474,11 +511,11 @@ class ModelSerializer(Serializer): self.m2m_data[field.name] = attrs.pop(field.name) return self.opts.model(**attrs) - def save(self, save_m2m=True): + def save(self, save_m2m=True, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + super(ModelSerializer, self).save(**kwargs) if getattr(self, 'm2m_data', None) and save_m2m: for accessor_name, object_list in self.m2m_data.items(): @@ -539,3 +576,10 @@ class HyperlinkedModelSerializer(ModelSerializer): if to_many: return ManyHyperlinkedRelatedField(**kwargs) return HyperlinkedRelatedField(**kwargs) + + def set_location_header(self): + if not super(HyperlinkedModelSerializer, self).set_location_header(): + self._headers['Location'] = self.data['url'] + return True + + return True \ No newline at end of file diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index d7effce70..9be65992d 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -221,7 +221,6 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): request = factory.post('/photos/', data=data) response = self.list_create_view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')