mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-23 22:49:50 +03:00
Merge ddf2124bb4
into a463ddbb37
This commit is contained in:
commit
986b311a08
|
@ -673,7 +673,7 @@ class HyperlinkedIdentityField(Field):
|
|||
except:
|
||||
pass
|
||||
|
||||
raise ValidationError('Could not resolve URL for field using view name "%s"', view_name)
|
||||
raise ValidationError('Could not resolve URL for field using view name "%s"' % view_name)
|
||||
|
||||
|
||||
##### Typed Fields #####
|
||||
|
|
|
@ -18,16 +18,12 @@ class CreateModelMixin(object):
|
|||
serializer = self.get_serializer(data=request.DATA, files=request.FILES)
|
||||
if serializer.is_valid():
|
||||
self.pre_save(serializer.object)
|
||||
self.object = serializer.save()
|
||||
headers = self.get_success_headers(serializer.data)
|
||||
self.object = serializer.save(force_insert=True)
|
||||
headers = self.get_response_headers(request, status.HTTP_201_CREATED, serializer=serializer)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
def get_success_headers(self, data):
|
||||
try:
|
||||
return {'Location': data['url']}
|
||||
except (TypeError, KeyError):
|
||||
return {}
|
||||
|
||||
headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer)
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers)
|
||||
|
||||
def pre_save(self, obj):
|
||||
pass
|
||||
|
@ -62,7 +58,8 @@ class ListModelMixin(object):
|
|||
else:
|
||||
serializer = self.get_serializer(self.object_list)
|
||||
|
||||
return Response(serializer.data)
|
||||
headers = self.get_response_headers(request, serializer=serializer)
|
||||
return Response(serializer.data, headers=headers)
|
||||
|
||||
|
||||
class RetrieveModelMixin(object):
|
||||
|
@ -73,7 +70,8 @@ class RetrieveModelMixin(object):
|
|||
def retrieve(self, request, *args, **kwargs):
|
||||
self.object = self.get_object()
|
||||
serializer = self.get_serializer(self.object)
|
||||
return Response(serializer.data)
|
||||
headers = self.get_response_headers(request, serializer=serializer)
|
||||
return Response(serializer.data, headers=headers)
|
||||
|
||||
|
||||
class UpdateModelMixin(object):
|
||||
|
@ -93,11 +91,16 @@ class UpdateModelMixin(object):
|
|||
|
||||
if serializer.is_valid():
|
||||
self.pre_save(serializer.object)
|
||||
self.object = serializer.save()
|
||||
if created:
|
||||
self.object = serializer.save(force_insert=True)
|
||||
else:
|
||||
self.object = serializer.save(force_update=True)
|
||||
status_code = created and status.HTTP_201_CREATED or status.HTTP_200_OK
|
||||
return Response(serializer.data, status=status_code)
|
||||
headers = self.get_response_headers(request, status_code, serializer=serializer)
|
||||
return Response(serializer.data, status=status_code, headers=headers)
|
||||
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
headers = self.get_response_headers(request, status.HTTP_400_BAD_REQUEST, serializer=serializer)
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST, headers=headers)
|
||||
|
||||
def pre_save(self, obj):
|
||||
"""
|
||||
|
@ -122,4 +125,5 @@ class DestroyModelMixin(object):
|
|||
def destroy(self, request, *args, **kwargs):
|
||||
self.object = self.get_object()
|
||||
self.object.delete()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
headers = self.get_response_headers(request, status.HTTP_204_NO_CONTENT, object=self.object)
|
||||
return Response(status=status.HTTP_204_NO_CONTENT, headers=headers)
|
||||
|
|
|
@ -10,7 +10,7 @@ import copy
|
|||
import string
|
||||
from django import forms
|
||||
from django.http.multipartparser import parse_header
|
||||
from django.template import RequestContext, loader, Template
|
||||
from django.template import loader, Template
|
||||
from django.utils import simplejson as json
|
||||
from rest_framework.compat import yaml
|
||||
from rest_framework.exceptions import ConfigurationError
|
||||
|
@ -28,7 +28,9 @@ class BaseRenderer(object):
|
|||
All renderers should extend this class, setting the `media_type`
|
||||
and `format` attributes, and override the `.render()` method.
|
||||
"""
|
||||
|
||||
|
||||
context_class = api_settings.DEFAULT_CONTEXT_CLASS
|
||||
|
||||
media_type = None
|
||||
format = None
|
||||
|
||||
|
@ -197,7 +199,7 @@ class TemplateHTMLRenderer(BaseRenderer):
|
|||
def resolve_context(self, data, request, response):
|
||||
if response.exception:
|
||||
data['status_code'] = response.status_code
|
||||
return RequestContext(request, data)
|
||||
return self.context_class(request, data)
|
||||
|
||||
def get_template_names(self, response, view):
|
||||
if response.template_name:
|
||||
|
@ -436,7 +438,7 @@ class BrowsableAPIRenderer(BaseRenderer):
|
|||
breadcrumb_list = get_breadcrumbs(request.path)
|
||||
|
||||
template = loader.get_template(self.template)
|
||||
context = RequestContext(request, {
|
||||
context = self.context_class(request, {
|
||||
'content': content,
|
||||
'view': view,
|
||||
'request': request,
|
||||
|
|
|
@ -127,7 +127,7 @@ class BaseSerializer(Field):
|
|||
"""
|
||||
return {}
|
||||
|
||||
def get_fields(self):
|
||||
def get_all_fields(self):
|
||||
"""
|
||||
Returns the complete set of fields for the object as a dict.
|
||||
|
||||
|
@ -149,6 +149,15 @@ class BaseSerializer(Field):
|
|||
if key not in ret:
|
||||
ret[key] = val
|
||||
|
||||
return ret
|
||||
|
||||
def get_fields(self):
|
||||
"""
|
||||
Returns a subset of fields specified by exclude and fields attr.
|
||||
of the Option Class.
|
||||
"""
|
||||
ret = self.get_all_fields()
|
||||
|
||||
# If 'fields' is specified, use those fields, in that order.
|
||||
if self.opts.fields:
|
||||
new = SortedDict()
|
||||
|
@ -321,11 +330,11 @@ class BaseSerializer(Field):
|
|||
self._data = self.to_native(self.object)
|
||||
return self._data
|
||||
|
||||
def save(self):
|
||||
def save(self, **kwargs):
|
||||
"""
|
||||
Save the deserialized object and return it.
|
||||
"""
|
||||
self.object.save()
|
||||
"""
|
||||
self.object.save(**kwargs)
|
||||
return self.object
|
||||
|
||||
|
||||
|
@ -491,11 +500,11 @@ class ModelSerializer(Serializer):
|
|||
self.m2m_data[field.name] = attrs.pop(field.name)
|
||||
return self.opts.model(**attrs)
|
||||
|
||||
def save(self, save_m2m=True):
|
||||
def save(self, save_m2m=True, **kwargs):
|
||||
"""
|
||||
Save the deserialized object and return it.
|
||||
"""
|
||||
self.object.save()
|
||||
super(ModelSerializer, self).save(**kwargs)
|
||||
|
||||
if getattr(self, 'm2m_data', None) and save_m2m:
|
||||
for accessor_name, object_list in self.m2m_data.items():
|
||||
|
|
|
@ -54,6 +54,9 @@ DEFAULTS = {
|
|||
'user': None,
|
||||
'anon': None,
|
||||
},
|
||||
|
||||
'DEFAULT_CONTEXT_CLASS':
|
||||
'django.template.RequestContext',
|
||||
|
||||
# Pagination
|
||||
'PAGINATE_BY': None,
|
||||
|
@ -84,6 +87,7 @@ IMPORT_STRINGS = (
|
|||
'DEFAULT_AUTHENTICATION_CLASSES',
|
||||
'DEFAULT_PERMISSION_CLASSES',
|
||||
'DEFAULT_THROTTLE_CLASSES',
|
||||
'DEFAULT_CONTEXT_CLASS',
|
||||
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
|
||||
'DEFAULT_MODEL_SERIALIZER_CLASS',
|
||||
'DEFAULT_PAGINATION_SERIALIZER_CLASS',
|
||||
|
|
|
@ -221,7 +221,6 @@ class TestCreateWithForeignKeysAndCustomSlug(TestCase):
|
|||
request = factory.post('/photos/', data=data)
|
||||
response = self.list_create_view(request).render()
|
||||
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
|
||||
self.assertNotIn('Location', response, msg='Location should only be included if there is a "url" field on the serializer')
|
||||
self.assertEqual(self.post.photo_set.count(), 1)
|
||||
self.assertEqual(self.post.photo_set.all()[0].description, 'A test photo')
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ from rest_framework.compat import View, apply_markdown
|
|||
from rest_framework.response import Response
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework import fields
|
||||
|
||||
|
||||
def _remove_trailing_string(content, trailing):
|
||||
|
@ -291,6 +292,23 @@ class APIView(View):
|
|||
# Perform content negotiation and store the accepted info on the request
|
||||
neg = self.perform_content_negotiation(request)
|
||||
request.accepted_renderer, request.accepted_media_type = neg
|
||||
|
||||
def get_response_headers(self, request, status_code=status.HTTP_200_OK, serializer=None, object=None):
|
||||
headers = {}
|
||||
|
||||
obj = object or (serializer and serializer.object)
|
||||
serializer_fields = serializer and serializer.get_all_fields()
|
||||
|
||||
if status_code == status.HTTP_201_CREATED:
|
||||
|
||||
if obj and hasattr(obj, 'get_absolute_url'):
|
||||
headers['Location'] = obj.get_absolute_url()
|
||||
elif obj and serializer_fields:
|
||||
for field_name, field in serializer_fields.iteritems():
|
||||
if isinstance(field, fields.HyperlinkedIdentityField):
|
||||
headers['Location'] = field.field_to_native(obj, field_name)
|
||||
|
||||
return headers
|
||||
|
||||
def finalize_response(self, request, response, *args, **kwargs):
|
||||
"""
|
||||
|
@ -371,4 +389,5 @@ class APIView(View):
|
|||
We may as well implement this as Django will otherwise provide
|
||||
a less useful default implementation.
|
||||
"""
|
||||
return Response(self.metadata(request), status=status.HTTP_200_OK)
|
||||
headers = self.get_response_headers(request)
|
||||
return Response(self.metadata(request), status=status.HTTP_200_OK, headers=headers)
|
||||
|
|
Loading…
Reference in New Issue
Block a user