From 8ecd419df7e5635c0d3041047f814b211955c6cc Mon Sep 17 00:00:00 2001 From: Ludwig Kraatz Date: Sat, 1 Dec 2012 13:41:34 +0100 Subject: [PATCH 1/6] header being set by serializer serializer has a .header property now Subclassed serializers can now change headers of the response (e.g. Pagination, Related Links, ...) --- rest_framework/mixins.py | 26 ++++------ rest_framework/serializers.py | 52 +++++++++++++++++-- .../tests/hyperlinkedserializers.py | 1 - 3 files changed, 59 insertions(+), 20 deletions(-) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 1edcfa5c9..23e241c03 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -18,16 +18,9 @@ class CreateModelMixin(object): serializer = self.get_serializer(data=request.DATA, files=request.FILES) if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() - headers = self.get_success_headers(serializer.data) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - def get_success_headers(self, data): - try: - return {'Location': data['url']} - except (TypeError, KeyError): - return {} + self.object = serializer.save(force_insert=True) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=serializer.headers) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers) def pre_save(self, obj): pass @@ -62,7 +55,7 @@ class ListModelMixin(object): else: serializer = self.get_serializer(self.object_list) - return Response(serializer.data) + return Response(serializer.data, headers=serializer.headers) class RetrieveModelMixin(object): @@ -73,7 +66,7 @@ class RetrieveModelMixin(object): def retrieve(self, request, *args, **kwargs): self.object = self.get_object() serializer = self.get_serializer(self.object) - return Response(serializer.data) + return Response(serializer.data, headers=serializer.headers) class UpdateModelMixin(object): @@ -93,11 +86,14 @@ class UpdateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) - self.object = serializer.save() + if created: + self.object = serializer.save(force_insert=True) + else: + self.object = serializer.save(force_update=True) status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK - return Response(serializer.data, status=status_code) + return Response(serializer.data, status=status_code, headers=serializer.headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers) def pre_save(self, obj): """ diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4519ab053..672d65c6c 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -108,6 +108,8 @@ class BaseSerializer(Field): self._data = None self._files = None self._errors = None + self._headers = {} + ##### # Methods to determine which fields to use when (de)serializing objects. @@ -304,12 +306,47 @@ class BaseSerializer(Field): self._data = self.to_native(self.object) return self._data - def save(self): + def save(self, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + pk_val = self.object._get_pk_val(self.object.__class__._meta) + pk_set = pk_val is not None + + if ((pk_set) and + ((('force_update' in kwargs) or ('update_fields' in kwargs)) or + ('force_insert' not in kwargs and self.object.__class__.objects.filter(pk=pk_val).exists()))): + created = False + else: + created = True + + self.object.save(**kwargs) + + if created: + self.set_location_header() + return self.object + + def generate_header(self): + return {} + + @property + def headers(self): + #self._headers.update(self.generate_header()) + return self._headers + + def set_location_header(self): + self._headers['Location'] = 'x' + if hasattr(self.object, 'get_absolute_url'): + self._headers['Location'] = self.object.get_absolute_url() + return True + else: + for field_name, field in self.fields.iteritems(): + if isinstance(field, HyperlinkedIdentityField): + self._headers['Location'] = field.field_to_native(self.object, field_name) + return True + + return False class Serializer(BaseSerializer): @@ -474,11 +511,11 @@ class ModelSerializer(Serializer): self.m2m_data[field.name] = attrs.pop(field.name) return self.opts.model(**attrs) - def save(self, save_m2m=True): + def save(self, save_m2m=True, **kwargs): """ Save the deserialized object and return it. """ - self.object.save() + super(ModelSerializer, self).save(**kwargs) if getattr(self, 'm2m_data', None) and save_m2m: for accessor_name, object_list in self.m2m_data.items(): @@ -539,3 +576,10 @@ class HyperlinkedModelSerializer(ModelSerializer): if to_many: return ManyHyperlinkedRelatedField(**kwargs) return HyperlinkedRelatedField(**kwargs) + + def set_location_header(self): + if not super(HyperlinkedModelSerializer, self).set_location_header(): + self._headers['Location'] = self.data['url'] + return True + + return True \ No newline at end of file diff --git a/rest_framework/tests/hyperlinkedserializers.py b/rest_framework/tests/hyperlinkedserializers.py index d7effce70..9be65992d 100644 --- a/rest_framework/tests/hyperlinkedserializers.py +++ b/rest_framework/tests/hyperlinkedserializers.py @@ -221,7 +221,6 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase): request = factory.post('/photos/', data=data) response = self.list_create_view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) - self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer') self.assertEqual(self.post.photo_set.count(), 1) self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo') From 448932c6598006c2b9b3c7010661e7815b92656a Mon Sep 17 00:00:00 2001 From: Ludwig Kraatz Date: Sat, 1 Dec 2012 13:51:05 +0100 Subject: [PATCH 2/6] Support custom Context Class in renderers --- rest_framework/renderers.py | 10 ++++++---- rest_framework/settings.py | 4 ++++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 25a32baae..108c3653c 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -10,7 +10,7 @@ import copy import string from django import forms from django.http.multipartparser import parse_header -from django.template import RequestContext, loader, Template +from django.template import loader, Template from django.utils import simplejson as json from rest_framework.compat import yaml from rest_framework.exceptions import ConfigurationError @@ -28,7 +28,9 @@ class BaseRenderer(object): All renderers should extend this class, setting the `media_type` and `format` attributes, and override the `.render()` method. """ - + + context_class = api_settings.DEFAULT_CONTEXT_CLASS + media_type = None format = None @@ -197,7 +199,7 @@ class TemplateHTMLRenderer(BaseRenderer): def resolve_context(self, data, request, response): if response.exception: data['status_code'] = response.status_code - return RequestContext(request, data) + return self.context_class(request, data) def get_template_names(self, response, view): if response.template_name: @@ -433,7 +435,7 @@ class BrowsableAPIRenderer(BaseRenderer): breadcrumb_list = get_breadcrumbs(request.path) template = loader.get_template(self.template) - context = RequestContext(request, { + context = self.context_class(request, { 'content': content, 'view': view, 'request': request, diff --git a/rest_framework/settings.py b/rest_framework/settings.py index ee24a4ad9..5de19d965 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -54,6 +54,9 @@ DEFAULTS = { 'user': None, 'anon': None, }, + + 'DEFAULT_CONTEXT_CLASS': + 'django.template.RequestContext', # Pagination 'PAGINATE_BY': None, @@ -84,6 +87,7 @@ IMPORT_STRINGS = ( 'DEFAULT_AUTHENTICATION_CLASSES', 'DEFAULT_PERMISSION_CLASSES', 'DEFAULT_THROTTLE_CLASSES', + 'DEFAULT_CONTEXT_CLASS', 'DEFAULT_CONTENT_NEGOTIATION_CLASS', 'DEFAULT_MODEL_SERIALIZER_CLASS', 'DEFAULT_PAGINATION_SERIALIZER_CLASS', From bd07cd8d207b0b2f860a6dae8eac9dd7d158ab48 Mon Sep 17 00:00:00 2001 From: Ludwig Kraatz Date: Sat, 1 Dec 2012 14:02:49 +0100 Subject: [PATCH 3/6] improved code just minor renaming / cleaning up --- rest_framework/serializers.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 672d65c6c..61bb54d3a 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -108,8 +108,7 @@ class BaseSerializer(Field): self._data = None self._files = None self._errors = None - self._headers = {} - + self._headers = {} ##### # Methods to determine which fields to use when (de)serializing objects. @@ -327,16 +326,16 @@ class BaseSerializer(Field): return self.object - def generate_header(self): + def _generate_headers(self): return {} @property def headers(self): - #self._headers.update(self.generate_header()) - return self._headers + ret = self._generate_headers() + ret.update(self._headers) + return ret def set_location_header(self): - self._headers['Location'] = 'x' if hasattr(self.object, 'get_absolute_url'): self._headers['Location'] = self.object.get_absolute_url() return True @@ -576,10 +575,3 @@ class HyperlinkedModelSerializer(ModelSerializer): if to_many: return ManyHyperlinkedRelatedField(**kwargs) return HyperlinkedRelatedField(**kwargs) - - def set_location_header(self): - if not super(HyperlinkedModelSerializer, self).set_location_header(): - self._headers['Location'] = self.data['url'] - return True - - return True \ No newline at end of file From 3bb1c2dcca18628c7112eb9e7d86fd53ea3e83a0 Mon Sep 17 00:00:00 2001 From: Ludwig Kraatz Date: Mon, 3 Dec 2012 08:55:03 +0100 Subject: [PATCH 4/6] moved header creation back to API View also seperated a get_all_fields() method from serializers get_field() - in order get have access to all available fields --- rest_framework/mixins.py | 22 ++++++++++----- rest_framework/serializers.py | 53 +++++++++-------------------------- rest_framework/views.py | 21 +++++++++++++- 3 files changed, 48 insertions(+), 48 deletions(-) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 23e241c03..0b1d53ea8 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -19,8 +19,11 @@ class CreateModelMixin(object): if serializer.is_valid(): self.pre_save(serializer.object) self.object = serializer.save(force_insert=True) - return Response(serializer.data, status=status.HTTP_201_CREATED, headers=serializer.headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers) + headers = self.get_response_headers(request, status.HTTP_201_CREATED, serializer=serializer) + return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) + + headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers) def pre_save(self, obj): pass @@ -55,7 +58,8 @@ class ListModelMixin(object): else: serializer = self.get_serializer(self.object_list) - return Response(serializer.data, headers=serializer.headers) + headers = self.get_response_headers(request, serializer=serializer) + return Response(serializer.data, headers=headers) class RetrieveModelMixin(object): @@ -66,7 +70,8 @@ class RetrieveModelMixin(object): def retrieve(self, request, *args, **kwargs): self.object = self.get_object() serializer = self.get_serializer(self.object) - return Response(serializer.data, headers=serializer.headers) + headers = self.get_response_headers(request, serializer=serializer) + return Response(serializer.data, headers=headers) class UpdateModelMixin(object): @@ -91,9 +96,11 @@ class UpdateModelMixin(object): else: self.object = serializer.save(force_update=True) status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK - return Response(serializer.data, status=status_code, headers=serializer.headers) + headers = self.get_response_headers(request, status_code, serializer=serializer) + return Response(serializer.data, status=status_code, headers=headers) - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers) + headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer) + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers) def pre_save(self, obj): """ @@ -118,4 +125,5 @@ class DestroyModelMixin(object): def destroy(self, request, *args, **kwargs): self.object = self.get_object() self.object.delete() - return Response(status=status.HTTP_204_NO_CONTENT) + headers = self.get_response_headers(request, status.HTTP_204_NO_CONTENT, object=self.object) + return Response(status=status.HTTP_204_NO_CONTENT, headers=headers) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 61bb54d3a..1bba5c9ed 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -107,8 +107,7 @@ class BaseSerializer(Field): self._data = None self._files = None - self._errors = None - self._headers = {} + self._errors = None ##### # Methods to determine which fields to use when (de)serializing objects. @@ -119,7 +118,7 @@ class BaseSerializer(Field): """ return {} - def get_fields(self): + def get_all_fields(self): """ Returns the complete set of fields for the object as a dict. @@ -141,6 +140,15 @@ class BaseSerializer(Field): if key not in ret: ret[key] = val + return ret + + def get_fields(self): + """ + Returns a subset of fields specified by exclude and fields attr. + of the Option Class. + """ + ret = self.get_all_fields() + # If 'fields' is specified, use those fields, in that order. if self.opts.fields: new = SortedDict() @@ -308,44 +316,9 @@ class BaseSerializer(Field): def save(self, **kwargs): """ Save the deserialized object and return it. - """ - pk_val = self.object._get_pk_val(self.object.__class__._meta) - pk_set = pk_val is not None - - if ((pk_set) and - ((('force_update' in kwargs) or ('update_fields' in kwargs)) or - ('force_insert' not in kwargs and self.object.__class__.objects.filter(pk=pk_val).exists()))): - created = False - else: - created = True - - self.object.save(**kwargs) - - if created: - self.set_location_header() - + """ + self.object.save(**kwargs) return self.object - - def _generate_headers(self): - return {} - - @property - def headers(self): - ret = self._generate_headers() - ret.update(self._headers) - return ret - - def set_location_header(self): - if hasattr(self.object, 'get_absolute_url'): - self._headers['Location'] = self.object.get_absolute_url() - return True - else: - for field_name, field in self.fields.iteritems(): - if isinstance(field, HyperlinkedIdentityField): - self._headers['Location'] = field.field_to_native(self.object, field_name) - return True - - return False class Serializer(BaseSerializer): diff --git a/rest_framework/views.py b/rest_framework/views.py index 10bdd5a53..604d95bd3 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -13,6 +13,7 @@ from rest_framework.compat import View, apply_markdown from rest_framework.response import Response from rest_framework.request import Request from rest_framework.settings import api_settings +from rest_framework import fields def _remove_trailing_string(content, trailing): @@ -291,6 +292,23 @@ class APIView(View): # Perform content negotiation and store the accepted info on the request neg = self.perform_content_negotiation(request) request.accepted_renderer, request.accepted_media_type = neg + + def get_response_headers(self, request, status_code=status.HTTP_200_OK, serializer=None, object=None): + headers = {} + + obj = object or (serializer and serializer.object) + serializer_fields = serializer and serializer.get_all_fields() + + if status_code == status.HTTP_201_CREATED: + + if obj and hasattr(obj, 'get_absolute_url'): + headers['Location'] = obj.get_absolute_url() + elif obj and serializer_fields: + for field_name, field in serializer_fields.iteritems(): + if isinstance(field, fields.HyperlinkedIdentityField): + headers['Location'] = field.field_to_native(obj, field_name) + + return headers def finalize_response(self, request, response, *args, **kwargs): """ @@ -371,4 +389,5 @@ class APIView(View): We may as well implement this as Django will otherwise provide a less useful default implementation. """ - return Response(self.metadata(request), status=status.HTTP_200_OK) + headers = self.get_response_headers(request) + return Response(self.metadata(request), status=status.HTTP_200_OK, headers=headers) From d76333b48e9ad679f9da6a01f4652f6fc7cafbd3 Mon Sep 17 00:00:00 2001 From: Ludwig Kraatz Date: Mon, 3 Dec 2012 08:58:17 +0100 Subject: [PATCH 5/6] removed extra whitespace --- rest_framework/serializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 1bba5c9ed..d7f5f531e 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -107,7 +107,7 @@ class BaseSerializer(Field): self._data = None self._files = None - self._errors = None + self._errors = None ##### # Methods to determine which fields to use when (de)serializing objects. From ddf2124bb4cc255796f3e8a771741ab9e2b9e8b7 Mon Sep 17 00:00:00 2001 From: Ludwig Kraatz Date: Mon, 3 Dec 2012 20:34:35 +0100 Subject: [PATCH 6/6] fixed some wrong string formating --- rest_framework/fields.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 482a3d485..08dd797d7 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -673,7 +673,7 @@ class HyperlinkedIdentityField(Field): except: pass - raise ValidationError('Could not resolve URL for field using view name "%s"', view_name) + raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name) ##### Typed Fields #####