mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-27 00:19:53 +03:00
header being set by serializer
serializer has a .header property now Subclassed serializers can now change headers of the response (e.g. Pagination, Related Links, ...)
This commit is contained in:
parent
3114b4fa50
commit
8ecd419df7
|
@ -18,16 +18,9 @@ class CreateModelMixin(object):
|
|||
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
|
||||
if serializer.is_valid():
|
||||
self.pre_save(serializer.object)
|
||||
self.object = serializer.save()
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
def get_success_headers(self, data):
|
||||
try:
|
||||
return {'Location': data['url']}
|
||||
except (TypeError, KeyError):
|
||||
return {}
|
||||
self.object = serializer.save(force_insert=True)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=serializer.headers)
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers)
|
||||
|
||||
def pre_save(self, obj):
|
||||
pass
|
||||
|
@ -62,7 +55,7 @@ class ListModelMixin(object):
|
|||
else:
|
||||
serializer = self.get_serializer(self.object_list)
|
||||
|
||||
return Response(serializer.data)
|
||||
return Response(serializer.data, headers=serializer.headers)
|
||||
|
||||
|
||||
class RetrieveModelMixin(object):
|
||||
|
@ -73,7 +66,7 @@ class RetrieveModelMixin(object):
|
|||
def retrieve(self, request, *args, **kwargs):
|
||||
self.object = self.get_object()
|
||||
serializer = self.get_serializer(self.object)
|
||||
return Response(serializer.data)
|
||||
return Response(serializer.data, headers=serializer.headers)
|
||||
|
||||
|
||||
class UpdateModelMixin(object):
|
||||
|
@ -93,11 +86,14 @@ class UpdateModelMixin(object):
|
|||
|
||||
if serializer.is_valid():
|
||||
self.pre_save(serializer.object)
|
||||
self.object = serializer.save()
|
||||
if created:
|
||||
self.object = serializer.save(force_insert=True)
|
||||
else:
|
||||
self.object = serializer.save(force_update=True)
|
||||
status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
|
||||
return Response(serializer.data, status=status_code)
|
||||
return Response(serializer.data, status=status_code, headers=serializer.headers)
|
||||
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=serializer.headers)
|
||||
|
||||
def pre_save(self, obj):
|
||||
"""
|
||||
|
|
|
@ -108,6 +108,8 @@ class BaseSerializer(Field):
|
|||
self._data = None
|
||||
self._files = None
|
||||
self._errors = None
|
||||
self._headers = {}
|
||||
|
||||
|
||||
#####
|
||||
# Methods to determine which fields to use when (de)serializing objects.
|
||||
|
@ -304,13 +306,48 @@ class BaseSerializer(Field):
|
|||
self._data = self.to_native(self.object)
|
||||
return self._data
|
||||
|
||||
def save(self):
|
||||
def save(self, **kwargs):
|
||||
"""
|
||||
Save the deserialized object and return it.
|
||||
"""
|
||||
self.object.save()
|
||||
pk_val = self.object._get_pk_val(self.object.__class__._meta)
|
||||
pk_set = pk_val is not None
|
||||
|
||||
if ((pk_set) and
|
||||
((('force_update' in kwargs) or ('update_fields' in kwargs)) or
|
||||
('force_insert' not in kwargs and self.object.__class__.objects.filter(pk=pk_val).exists()))):
|
||||
created = False
|
||||
else:
|
||||
created = True
|
||||
|
||||
self.object.save(**kwargs)
|
||||
|
||||
if created:
|
||||
self.set_location_header()
|
||||
|
||||
return self.object
|
||||
|
||||
def generate_header(self):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
#self._headers.update(self.generate_header())
|
||||
return self._headers
|
||||
|
||||
def set_location_header(self):
|
||||
self._headers['Location'] = 'x'
|
||||
if hasattr(self.object, 'get_absolute_url'):
|
||||
self._headers['Location'] = self.object.get_absolute_url()
|
||||
return True
|
||||
else:
|
||||
for field_name, field in self.fields.iteritems():
|
||||
if isinstance(field, HyperlinkedIdentityField):
|
||||
self._headers['Location'] = field.field_to_native(self.object, field_name)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class Serializer(BaseSerializer):
|
||||
__metaclass__ = SerializerMetaclass
|
||||
|
@ -474,11 +511,11 @@ class ModelSerializer(Serializer):
|
|||
self.m2m_data[field.name] = attrs.pop(field.name)
|
||||
return self.opts.model(**attrs)
|
||||
|
||||
def save(self, save_m2m=True):
|
||||
def save(self, save_m2m=True, **kwargs):
|
||||
"""
|
||||
Save the deserialized object and return it.
|
||||
"""
|
||||
self.object.save()
|
||||
super(ModelSerializer, self).save(**kwargs)
|
||||
|
||||
if getattr(self, 'm2m_data', None) and save_m2m:
|
||||
for accessor_name, object_list in self.m2m_data.items():
|
||||
|
@ -539,3 +576,10 @@ class HyperlinkedModelSerializer(ModelSerializer):
|
|||
if to_many:
|
||||
return ManyHyperlinkedRelatedField(**kwargs)
|
||||
return HyperlinkedRelatedField(**kwargs)
|
||||
|
||||
def set_location_header(self):
|
||||
if not super(HyperlinkedModelSerializer, self).set_location_header():
|
||||
self._headers['Location'] = self.data['url']
|
||||
return True
|
||||
|
||||
return True
|
|
@ -221,7 +221,6 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
|
|||
request = factory.post('/photos/', data=data)
|
||||
response = self.list_create_view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
|
||||
self.assertEqual(self.post.photo_set.count(), 1)
|
||||
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user