diff --git a/docs/topics/credits.md b/docs/topics/credits.md index 22d08df7a..8e71c937a 100644 --- a/docs/topics/credits.md +++ b/docs/topics/credits.md @@ -59,6 +59,7 @@ The following people have helped make REST framework great. * Toni Michel - [tonimichel] * Ben Konrath - [benkonrath] * Marc Aymerich - [glic3rinu] +* Ludwig Kraatz - [ludwigkraatz] Many thanks to everyone who's contributed to the project. @@ -153,3 +154,4 @@ To contact the author directly: [tonimichel]: https://github.com/tonimichel [benkonrath]: https://github.com/benkonrath [glic3rinu]: https://github.com/glic3rinu +[ludwigkraatz]: https://github.com/ludwigkraatz diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index c3625a88e..cd104a7c9 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -19,9 +19,16 @@ class CreateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) self.object = serializer.save() - return Response(serializer.data, status=status.HTTP_201_CREATED) + headers = self.get_success_headers(serializer.data) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - + + def get_success_headers(self, data): + if 'url' in data: + return {'Location': data.get('url')} + else: + return {} + def pre_save(self, obj): pass diff --git a/rest_framework/response.py b/rest_framework/response.py index 0bd6c65d7..be78c43ae 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -20,9 +20,12 @@ class Response(SimpleTemplateResponse): """ super(Response, self).__init__(None, status=status) self.data = data - self.headers = headers and headers[:] or [] self.template_name = template_name self.exception = exception + + if headers: + for name,value in headers.iteritems(): + self[name] = value @property def rendered_content(self): diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index 5ab850afa..d7effce70 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -8,12 +8,13 @@ factory = RequestFactory() class BlogPostCommentSerializer(serializers.ModelSerializer): + url = serializers.HyperlinkedIdentityField(view_name='blogpostcomment-detail') text = serializers.CharField() blog_post_url = serializers.HyperlinkedRelatedField(source='blog_post', view_name='blogpost-detail') class Meta: model = BlogPostComment - fields = ('text', 'blog_post_url') + fields = ('text', 'blog_post_url', 'url') class PhotoSerializer(serializers.Serializer): @@ -53,6 +54,9 @@ class BlogPostCommentListCreate(generics.ListCreateAPIView): model = BlogPostComment serializer_class = BlogPostCommentSerializer +class BlogPostCommentDetail(generics.RetrieveAPIView): + model = BlogPostComment + serializer_class = BlogPostCommentSerializer class BlogPostDetail(generics.RetrieveAPIView): model = BlogPost @@ -80,6 +84,7 @@ urlpatterns = patterns('', url(r'^manytomany/(?P\d+)/$', ManyToManyDetail.as_view(), name='manytomanymodel-detail'), url(r'^posts/(?P\d+)/$', BlogPostDetail.as_view(), name='blogpost-detail'), url(r'^comments/$', BlogPostCommentListCreate.as_view(), name='blogpostcomment-list'), + url(r'^comments/(?P\d+)/$', BlogPostCommentDetail.as_view(), name='blogpostcomment-detail'), url(r'^albums/(?P\w[\w-]*)/$', AlbumDetail.as_view(), name='album-detail'), url(r'^photos/$', PhotoListCreate.as_view(), name='photo-list'), url(r'^optionalrelation/(?P<pk>\d+)/$', OptionalRelationDetail.as_view(), name='optionalrelationmodel-detail'), @@ -191,6 +196,7 @@ class TestCreateWithForeignKeys(TestCase): request = factory.post('/comments/', data=data) response = self.create_view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(response['Location'], 'http://testserver/comments/1/') self.assertEqual(self.post.blogpostcomment_set.count(), 1) self.assertEqual(self.post.blogpostcomment_set.all()[0].text, 'A test comment') @@ -215,6 +221,7 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): request = factory.post('/photos/', data=data) response = self.list_create_view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') diff --git a/rest_framework/tests/throttling.py b/rest_framework/tests/throttling.py index 0b94c25ba..4b98b9414 100644 --- a/rest_framework/tests/throttling.py +++ b/rest_framework/tests/throttling.py @@ -106,7 +106,7 @@ class ThrottlingTests(TestCase): if expect is not None: self.assertEquals(response['X-Throttle-Wait-Seconds'], expect) else: - self.assertFalse('X-Throttle-Wait-Seconds' in response.headers) + self.assertFalse('X-Throttle-Wait-Seconds' in response) def test_seconds_fields(self): """