mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-26 08:00:07 +03:00
moved header creation back to API View
also seperated a get_all_fields() method from serializers get_field() - in order get have access to all available fields
This commit is contained in:
parent
bd07cd8d20
commit
3bb1c2dcca
|
@ -19,8 +19,11 @@ class CreateModelMixin(object):
|
||||||
if serializer.is_valid():
|
if serializer.is_valid():
|
||||||
self.pre_save(serializer.object)
|
self.pre_save(serializer.object)
|
||||||
self.object = serializer.save(force_insert=True)
|
self.object = serializer.save(force_insert=True)
|
||||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=serializer.headers)
|
headers = self.get_response_headers(request, status.HTTP_201_CREATED, serializer=serializer)
|
||||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers)
|
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||||
|
|
||||||
|
headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer)
|
||||||
|
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers)
|
||||||
|
|
||||||
def pre_save(self, obj):
|
def pre_save(self, obj):
|
||||||
pass
|
pass
|
||||||
|
@ -55,7 +58,8 @@ class ListModelMixin(object):
|
||||||
else:
|
else:
|
||||||
serializer = self.get_serializer(self.object_list)
|
serializer = self.get_serializer(self.object_list)
|
||||||
|
|
||||||
return Response(serializer.data, headers=serializer.headers)
|
headers = self.get_response_headers(request, serializer=serializer)
|
||||||
|
return Response(serializer.data, headers=headers)
|
||||||
|
|
||||||
|
|
||||||
class RetrieveModelMixin(object):
|
class RetrieveModelMixin(object):
|
||||||
|
@ -66,7 +70,8 @@ class RetrieveModelMixin(object):
|
||||||
def retrieve(self, request, *args, **kwargs):
|
def retrieve(self, request, *args, **kwargs):
|
||||||
self.object = self.get_object()
|
self.object = self.get_object()
|
||||||
serializer = self.get_serializer(self.object)
|
serializer = self.get_serializer(self.object)
|
||||||
return Response(serializer.data, headers=serializer.headers)
|
headers = self.get_response_headers(request, serializer=serializer)
|
||||||
|
return Response(serializer.data, headers=headers)
|
||||||
|
|
||||||
|
|
||||||
class UpdateModelMixin(object):
|
class UpdateModelMixin(object):
|
||||||
|
@ -91,9 +96,11 @@ class UpdateModelMixin(object):
|
||||||
else:
|
else:
|
||||||
self.object = serializer.save(force_update=True)
|
self.object = serializer.save(force_update=True)
|
||||||
status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
|
status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
|
||||||
return Response(serializer.data, status=status_code, headers=serializer.headers)
|
headers = self.get_response_headers(request, status_code, serializer=serializer)
|
||||||
|
return Response(serializer.data, status=status_code, headers=headers)
|
||||||
|
|
||||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers)
|
headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer)
|
||||||
|
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers)
|
||||||
|
|
||||||
def pre_save(self, obj):
|
def pre_save(self, obj):
|
||||||
"""
|
"""
|
||||||
|
@ -118,4 +125,5 @@ class DestroyModelMixin(object):
|
||||||
def destroy(self, request, *args, **kwargs):
|
def destroy(self, request, *args, **kwargs):
|
||||||
self.object = self.get_object()
|
self.object = self.get_object()
|
||||||
self.object.delete()
|
self.object.delete()
|
||||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
headers = self.get_response_headers(request, status.HTTP_204_NO_CONTENT, object=self.object)
|
||||||
|
return Response(status=status.HTTP_204_NO_CONTENT, headers=headers)
|
||||||
|
|
|
@ -107,8 +107,7 @@ class BaseSerializer(Field):
|
||||||
|
|
||||||
self._data = None
|
self._data = None
|
||||||
self._files = None
|
self._files = None
|
||||||
self._errors = None
|
self._errors = None
|
||||||
self._headers = {}
|
|
||||||
|
|
||||||
#####
|
#####
|
||||||
# Methods to determine which fields to use when (de)serializing objects.
|
# Methods to determine which fields to use when (de)serializing objects.
|
||||||
|
@ -119,7 +118,7 @@ class BaseSerializer(Field):
|
||||||
"""
|
"""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def get_fields(self):
|
def get_all_fields(self):
|
||||||
"""
|
"""
|
||||||
Returns the complete set of fields for the object as a dict.
|
Returns the complete set of fields for the object as a dict.
|
||||||
|
|
||||||
|
@ -141,6 +140,15 @@ class BaseSerializer(Field):
|
||||||
if key not in ret:
|
if key not in ret:
|
||||||
ret[key] = val
|
ret[key] = val
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def get_fields(self):
|
||||||
|
"""
|
||||||
|
Returns a subset of fields specified by exclude and fields attr.
|
||||||
|
of the Option Class.
|
||||||
|
"""
|
||||||
|
ret = self.get_all_fields()
|
||||||
|
|
||||||
# If 'fields' is specified, use those fields, in that order.
|
# If 'fields' is specified, use those fields, in that order.
|
||||||
if self.opts.fields:
|
if self.opts.fields:
|
||||||
new = SortedDict()
|
new = SortedDict()
|
||||||
|
@ -308,44 +316,9 @@ class BaseSerializer(Field):
|
||||||
def save(self, **kwargs):
|
def save(self, **kwargs):
|
||||||
"""
|
"""
|
||||||
Save the deserialized object and return it.
|
Save the deserialized object and return it.
|
||||||
"""
|
"""
|
||||||
pk_val = self.object._get_pk_val(self.object.__class__._meta)
|
self.object.save(**kwargs)
|
||||||
pk_set = pk_val is not None
|
|
||||||
|
|
||||||
if ((pk_set) and
|
|
||||||
((('force_update' in kwargs) or ('update_fields' in kwargs)) or
|
|
||||||
('force_insert' not in kwargs and self.object.__class__.objects.filter(pk=pk_val).exists()))):
|
|
||||||
created = False
|
|
||||||
else:
|
|
||||||
created = True
|
|
||||||
|
|
||||||
self.object.save(**kwargs)
|
|
||||||
|
|
||||||
if created:
|
|
||||||
self.set_location_header()
|
|
||||||
|
|
||||||
return self.object
|
return self.object
|
||||||
|
|
||||||
def _generate_headers(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def headers(self):
|
|
||||||
ret = self._generate_headers()
|
|
||||||
ret.update(self._headers)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def set_location_header(self):
|
|
||||||
if hasattr(self.object, 'get_absolute_url'):
|
|
||||||
self._headers['Location'] = self.object.get_absolute_url()
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
for field_name, field in self.fields.iteritems():
|
|
||||||
if isinstance(field, HyperlinkedIdentityField):
|
|
||||||
self._headers['Location'] = field.field_to_native(self.object, field_name)
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
class Serializer(BaseSerializer):
|
class Serializer(BaseSerializer):
|
||||||
|
|
|
@ -13,6 +13,7 @@ from rest_framework.compat import View, apply_markdown
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
from rest_framework import fields
|
||||||
|
|
||||||
|
|
||||||
def _remove_trailing_string(content, trailing):
|
def _remove_trailing_string(content, trailing):
|
||||||
|
@ -291,6 +292,23 @@ class APIView(View):
|
||||||
# Perform content negotiation and store the accepted info on the request
|
# Perform content negotiation and store the accepted info on the request
|
||||||
neg = self.perform_content_negotiation(request)
|
neg = self.perform_content_negotiation(request)
|
||||||
request.accepted_renderer, request.accepted_media_type = neg
|
request.accepted_renderer, request.accepted_media_type = neg
|
||||||
|
|
||||||
|
def get_response_headers(self, request, status_code=status.HTTP_200_OK, serializer=None, object=None):
|
||||||
|
headers = {}
|
||||||
|
|
||||||
|
obj = object or (serializer and serializer.object)
|
||||||
|
serializer_fields = serializer and serializer.get_all_fields()
|
||||||
|
|
||||||
|
if status_code == status.HTTP_201_CREATED:
|
||||||
|
|
||||||
|
if obj and hasattr(obj, 'get_absolute_url'):
|
||||||
|
headers['Location'] = obj.get_absolute_url()
|
||||||
|
elif obj and serializer_fields:
|
||||||
|
for field_name, field in serializer_fields.iteritems():
|
||||||
|
if isinstance(field, fields.HyperlinkedIdentityField):
|
||||||
|
headers['Location'] = field.field_to_native(obj, field_name)
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
def finalize_response(self, request, response, *args, **kwargs):
|
def finalize_response(self, request, response, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -371,4 +389,5 @@ class APIView(View):
|
||||||
We may as well implement this as Django will otherwise provide
|
We may as well implement this as Django will otherwise provide
|
||||||
a less useful default implementation.
|
a less useful default implementation.
|
||||||
"""
|
"""
|
||||||
return Response(self.metadata(request), status=status.HTTP_200_OK)
|
headers = self.get_response_headers(request)
|
||||||
|
return Response(self.metadata(request), status=status.HTTP_200_OK, headers=headers)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user