mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-04 01:47:59 +03:00 
			
		
		
		
	Merge pull request #931 from mindlace-mp/writable-nested-modelserializer
Merged master into writable-nested-modelserializer.
This commit is contained in:
		
						commit
						170709442b
					
				| 
						 | 
					@ -60,7 +60,7 @@ The following attributes control the basic view behavior.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
* `queryset` - The queryset that should be used for returning objects from this view.  Typically, you must either set this attribute, or override the `get_queryset()` method.
 | 
					* `queryset` - The queryset that should be used for returning objects from this view.  Typically, you must either set this attribute, or override the `get_queryset()` method.
 | 
				
			||||||
* `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output.  Typically, you must either set this attribute, or override the `get_serializer_class()` method.
 | 
					* `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output.  Typically, you must either set this attribute, or override the `get_serializer_class()` method.
 | 
				
			||||||
* `lookup_field` - The field that should be used to lookup individual model instances.  Defaults to `'pk'`.  The URL conf should include a keyword argument corresponding to this value.  More complex lookup styles can be supported by overriding the `get_object()` method.
 | 
					* `lookup_field` - The field that should be used to lookup individual model instances.  Defaults to `'pk'`.  The URL conf should include a keyword argument corresponding to this value.  More complex lookup styles can be supported by overriding the `get_object()` method.  Note that when using hyperlinked APIs you'll need to ensure that *both* the API views *and* the serializer classes use lookup fields that correctly correspond with the URL conf.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
**Shortcuts**:
 | 
					**Shortcuts**:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -131,7 +131,7 @@ You may want to override this method to provide more complex behavior such as mo
 | 
				
			||||||
For example:
 | 
					For example:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_paginate_by(self):
 | 
					    def get_paginate_by(self):
 | 
				
			||||||
        self.request.accepted_renderer.format == 'html':
 | 
					        if self.request.accepted_renderer.format == 'html':
 | 
				
			||||||
            return 20
 | 
					            return 20
 | 
				
			||||||
        return 100
 | 
					        return 100
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,7 +39,7 @@ Declaring a serializer looks very similar to declaring a form:
 | 
				
			||||||
            an existing model instance, or create a new model instance.
 | 
					            an existing model instance, or create a new model instance.
 | 
				
			||||||
            """
 | 
					            """
 | 
				
			||||||
            if instance is not None:
 | 
					            if instance is not None:
 | 
				
			||||||
                instance.title = attrs.get('title', instance.title)
 | 
					                instance.email = attrs.get('email', instance.email)
 | 
				
			||||||
                instance.content = attrs.get('content', instance.content)
 | 
					                instance.content = attrs.get('content', instance.content)
 | 
				
			||||||
                instance.created = attrs.get('created', instance.created)
 | 
					                instance.created = attrs.get('created', instance.created)
 | 
				
			||||||
                return instance
 | 
					                return instance
 | 
				
			||||||
| 
						 | 
					@ -387,7 +387,7 @@ There needs to be a way of determining which views should be used for hyperlinki
 | 
				
			||||||
 | 
					
 | 
				
			||||||
By default hyperlinks are expected to correspond to a view name that matches the style `'{model_name}-detail'`, and looks up the instance by a `pk` keyword argument.
 | 
					By default hyperlinks are expected to correspond to a view name that matches the style `'{model_name}-detail'`, and looks up the instance by a `pk` keyword argument.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
You can change the field that is used for object lookups by setting the `lookup_field` option.  The value of this option should correspond both with a kwarg in the URL conf, and with an field on the model.  For example:
 | 
					You can change the field that is used for object lookups by setting the `lookup_field` option.  The value of this option should correspond both with a kwarg in the URL conf, and with a field on the model.  For example:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class AccountSerializer(serializers.HyperlinkedModelSerializer):
 | 
					    class AccountSerializer(serializers.HyperlinkedModelSerializer):
 | 
				
			||||||
        class Meta:
 | 
					        class Meta:
 | 
				
			||||||
| 
						 | 
					@ -395,6 +395,8 @@ You can change the field that is used for object lookups by setting the `lookup_
 | 
				
			||||||
            fields = ('url', 'account_name', 'users', 'created')
 | 
					            fields = ('url', 'account_name', 'users', 'created')
 | 
				
			||||||
            lookup_field = 'slug'
 | 
					            lookup_field = 'slug'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					Not that the `lookup_field` will be used as the default on *all* hyperlinked fields, including both the URL identity, and any hyperlinked relationships.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
For more specfic requirements such as specifying a different lookup for each field, you'll want to set the fields on the serializer explicitly.  For example:
 | 
					For more specfic requirements such as specifying a different lookup for each field, you'll want to set the fields on the serializer explicitly.  For example:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class AccountSerializer(serializers.HyperlinkedModelSerializer):
 | 
					    class AccountSerializer(serializers.HyperlinkedModelSerializer):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -209,8 +209,6 @@ To create a base viewset class that provides `create`, `list` and `retrieve` ope
 | 
				
			||||||
                                    mixins.ListMixin,
 | 
					                                    mixins.ListMixin,
 | 
				
			||||||
                                    mixins.RetrieveMixin,
 | 
					                                    mixins.RetrieveMixin,
 | 
				
			||||||
                                    viewsets.GenericViewSet):
 | 
					                                    viewsets.GenericViewSet):
 | 
				
			||||||
        pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        A viewset that provides `retrieve`, `update`, and `list` actions.
 | 
					        A viewset that provides `retrieve`, `update`, and `list` actions.
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -139,6 +139,9 @@ The following people have helped make REST framework great.
 | 
				
			||||||
* Pascal Borreli - [pborreli]
 | 
					* Pascal Borreli - [pborreli]
 | 
				
			||||||
* Alex Burgel - [aburgel]
 | 
					* Alex Burgel - [aburgel]
 | 
				
			||||||
* David Medina - [copitux]
 | 
					* David Medina - [copitux]
 | 
				
			||||||
 | 
					* Areski Belaid - [areski]
 | 
				
			||||||
 | 
					* Ethan Freman - [mindlace]
 | 
				
			||||||
 | 
					* David Sanders - [davesque]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Many thanks to everyone who's contributed to the project.
 | 
					Many thanks to everyone who's contributed to the project.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -314,3 +317,7 @@ You can also contact [@_tomchristie][twitter] directly on twitter.
 | 
				
			||||||
[pborreli]: https://github.com/pborreli
 | 
					[pborreli]: https://github.com/pborreli
 | 
				
			||||||
[aburgel]: https://github.com/aburgel
 | 
					[aburgel]: https://github.com/aburgel
 | 
				
			||||||
[copitux]: https://github.com/copitux
 | 
					[copitux]: https://github.com/copitux
 | 
				
			||||||
 | 
					[areski]: https://github.com/areski
 | 
				
			||||||
 | 
					[mindlace]: https://github.com/mindlace
 | 
				
			||||||
 | 
					[davesque]: https://github.com/davesque
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -10,7 +10,9 @@ A `ViewSet` class is only bound to a set of method handlers at the last moment,
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Let's take our current set of views, and refactor them into view sets.
 | 
					Let's take our current set of views, and refactor them into view sets.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
First of all let's refactor our `UserListView` and `UserDetailView` views into a single `UserViewSet`.  We can remove the two views, and replace then with a single class:
 | 
					First of all let's refactor our `UserList` and `UserDetail` views into a single `UserViewSet`.  We can remove the two views, and replace then with a single class:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    from rest_framework import viewsets
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class UserViewSet(viewsets.ReadOnlyModelViewSet):
 | 
					    class UserViewSet(viewsets.ReadOnlyModelViewSet):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -23,7 +25,6 @@ Here we've used `ReadOnlyModelViewSet` class to automatically provide the defaul
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes.  We can remove the three views, and again replace them with a single class.
 | 
					Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes.  We can remove the three views, and again replace them with a single class.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    from rest_framework import viewsets
 | 
					 | 
				
			||||||
    from rest_framework.decorators import link
 | 
					    from rest_framework.decorators import link
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    class SnippetViewSet(viewsets.ModelViewSet):
 | 
					    class SnippetViewSet(viewsets.ModelViewSet):
 | 
				
			||||||
| 
						 | 
					@ -73,7 +74,7 @@ In the `urls.py` file we bind our `ViewSet` classes into a set of concrete views
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
    snippet_highlight = SnippetViewSet.as_view({
 | 
					    snippet_highlight = SnippetViewSet.as_view({
 | 
				
			||||||
        'get': 'highlight'
 | 
					        'get': 'highlight'
 | 
				
			||||||
    })
 | 
					    }, renderer_classes=[renderers.StaticHTMLRenderer])
 | 
				
			||||||
    user_list = UserViewSet.as_view({
 | 
					    user_list = UserViewSet.as_view({
 | 
				
			||||||
        'get': 'list'
 | 
					        'get': 'list'
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -230,8 +230,9 @@ class OAuthAuthentication(BaseAuthentication):
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            consumer_key = oauth_request.get_parameter('oauth_consumer_key')
 | 
					            consumer_key = oauth_request.get_parameter('oauth_consumer_key')
 | 
				
			||||||
            consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
 | 
					            consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
 | 
				
			||||||
        except oauth_provider.store.InvalidConsumerError as err:
 | 
					        except oauth_provider.store.InvalidConsumerError:
 | 
				
			||||||
            raise exceptions.AuthenticationFailed(err)
 | 
					            msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key')
 | 
				
			||||||
 | 
					            raise exceptions.AuthenticationFailed(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if consumer.status != oauth_provider.consts.ACCEPTED:
 | 
					        if consumer.status != oauth_provider.consts.ACCEPTED:
 | 
				
			||||||
            msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
 | 
					            msg = 'Invalid consumer key status: %s' % consumer.get_status_display()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -86,10 +86,3 @@ class Throttled(APIException):
 | 
				
			||||||
            self.detail = format % (self.wait, self.wait != 1 and 's' or '')
 | 
					            self.detail = format % (self.wait, self.wait != 1 and 's' or '')
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self.detail = detail or self.default_detail
 | 
					            self.detail = detail or self.default_detail
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class ConfigurationError(Exception):
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    Indicates an internal server error.
 | 
					 | 
				
			||||||
    """
 | 
					 | 
				
			||||||
    pass
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -7,25 +7,24 @@ from __future__ import unicode_literals
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import copy
 | 
					import copy
 | 
				
			||||||
import datetime
 | 
					import datetime
 | 
				
			||||||
from decimal import Decimal, DecimalException
 | 
					 | 
				
			||||||
import inspect
 | 
					import inspect
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					from decimal import Decimal, DecimalException
 | 
				
			||||||
 | 
					from django import forms
 | 
				
			||||||
from django.core import validators
 | 
					from django.core import validators
 | 
				
			||||||
from django.core.exceptions import ValidationError
 | 
					from django.core.exceptions import ValidationError
 | 
				
			||||||
from django.conf import settings
 | 
					from django.conf import settings
 | 
				
			||||||
from django.db.models.fields import BLANK_CHOICE_DASH
 | 
					from django.db.models.fields import BLANK_CHOICE_DASH
 | 
				
			||||||
from django import forms
 | 
					 | 
				
			||||||
from django.forms import widgets
 | 
					from django.forms import widgets
 | 
				
			||||||
from django.utils.encoding import is_protected_type
 | 
					from django.utils.encoding import is_protected_type
 | 
				
			||||||
from django.utils.translation import ugettext_lazy as _
 | 
					from django.utils.translation import ugettext_lazy as _
 | 
				
			||||||
from django.utils.datastructures import SortedDict
 | 
					from django.utils.datastructures import SortedDict
 | 
				
			||||||
from rest_framework import ISO_8601
 | 
					from rest_framework import ISO_8601
 | 
				
			||||||
from rest_framework.compat import (timezone, parse_date, parse_datetime,
 | 
					from rest_framework.compat import (
 | 
				
			||||||
                                   parse_time)
 | 
					    timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text,
 | 
				
			||||||
from rest_framework.compat import BytesIO
 | 
					    force_text, is_non_str_iterable
 | 
				
			||||||
from rest_framework.compat import six
 | 
					)
 | 
				
			||||||
from rest_framework.compat import smart_text, force_text, is_non_str_iterable
 | 
					 | 
				
			||||||
from rest_framework.settings import api_settings
 | 
					from rest_framework.settings import api_settings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -256,6 +255,12 @@ class WritableField(Field):
 | 
				
			||||||
            widget = widget()
 | 
					            widget = widget()
 | 
				
			||||||
        self.widget = widget
 | 
					        self.widget = widget
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __deepcopy__(self, memo):
 | 
				
			||||||
 | 
					        result = copy.copy(self)
 | 
				
			||||||
 | 
					        memo[id(self)] = result
 | 
				
			||||||
 | 
					        result.validators = self.validators[:]
 | 
				
			||||||
 | 
					        return result
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def validate(self, value):
 | 
					    def validate(self, value):
 | 
				
			||||||
        if value in validators.EMPTY_VALUES and self.required:
 | 
					        if value in validators.EMPTY_VALUES and self.required:
 | 
				
			||||||
            raise ValidationError(self.error_messages['required'])
 | 
					            raise ValidationError(self.error_messages['required'])
 | 
				
			||||||
| 
						 | 
					@ -331,9 +336,13 @@ class ModelField(WritableField):
 | 
				
			||||||
            raise ValueError("ModelField requires 'model_field' kwarg")
 | 
					            raise ValueError("ModelField requires 'model_field' kwarg")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.min_length = kwargs.pop('min_length',
 | 
					        self.min_length = kwargs.pop('min_length',
 | 
				
			||||||
                            getattr(self.model_field, 'min_length', None))
 | 
					                                     getattr(self.model_field, 'min_length', None))
 | 
				
			||||||
        self.max_length = kwargs.pop('max_length',
 | 
					        self.max_length = kwargs.pop('max_length',
 | 
				
			||||||
                            getattr(self.model_field, 'max_length', None))
 | 
					                                     getattr(self.model_field, 'max_length', None))
 | 
				
			||||||
 | 
					        self.min_value = kwargs.pop('min_value',
 | 
				
			||||||
 | 
					                                    getattr(self.model_field, 'min_value', None))
 | 
				
			||||||
 | 
					        self.max_value = kwargs.pop('max_value',
 | 
				
			||||||
 | 
					                                    getattr(self.model_field, 'max_value', None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        super(ModelField, self).__init__(*args, **kwargs)
 | 
					        super(ModelField, self).__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -341,6 +350,10 @@ class ModelField(WritableField):
 | 
				
			||||||
            self.validators.append(validators.MinLengthValidator(self.min_length))
 | 
					            self.validators.append(validators.MinLengthValidator(self.min_length))
 | 
				
			||||||
        if self.max_length is not None:
 | 
					        if self.max_length is not None:
 | 
				
			||||||
            self.validators.append(validators.MaxLengthValidator(self.max_length))
 | 
					            self.validators.append(validators.MaxLengthValidator(self.max_length))
 | 
				
			||||||
 | 
					        if self.min_value is not None:
 | 
				
			||||||
 | 
					            self.validators.append(validators.MinValueValidator(self.min_value))
 | 
				
			||||||
 | 
					        if self.max_value is not None:
 | 
				
			||||||
 | 
					            self.validators.append(validators.MaxValueValidator(self.max_value))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def from_native(self, value):
 | 
					    def from_native(self, value):
 | 
				
			||||||
        rel = getattr(self.model_field, "rel", None)
 | 
					        rel = getattr(self.model_field, "rel", None)
 | 
				
			||||||
| 
						 | 
					@ -428,13 +441,6 @@ class SlugField(CharField):
 | 
				
			||||||
    def __init__(self, *args, **kwargs):
 | 
					    def __init__(self, *args, **kwargs):
 | 
				
			||||||
        super(SlugField, self).__init__(*args, **kwargs)
 | 
					        super(SlugField, self).__init__(*args, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __deepcopy__(self, memo):
 | 
					 | 
				
			||||||
        result = copy.copy(self)
 | 
					 | 
				
			||||||
        memo[id(self)] = result
 | 
					 | 
				
			||||||
        #result.widget = copy.deepcopy(self.widget, memo)
 | 
					 | 
				
			||||||
        result.validators = self.validators[:]
 | 
					 | 
				
			||||||
        return result
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ChoiceField(WritableField):
 | 
					class ChoiceField(WritableField):
 | 
				
			||||||
    type_name = 'ChoiceField'
 | 
					    type_name = 'ChoiceField'
 | 
				
			||||||
| 
						 | 
					@ -503,13 +509,6 @@ class EmailField(CharField):
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
        return ret.strip()
 | 
					        return ret.strip()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __deepcopy__(self, memo):
 | 
					 | 
				
			||||||
        result = copy.copy(self)
 | 
					 | 
				
			||||||
        memo[id(self)] = result
 | 
					 | 
				
			||||||
        #result.widget = copy.deepcopy(self.widget, memo)
 | 
					 | 
				
			||||||
        result.validators = self.validators[:]
 | 
					 | 
				
			||||||
        return result
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class RegexField(CharField):
 | 
					class RegexField(CharField):
 | 
				
			||||||
    type_name = 'RegexField'
 | 
					    type_name = 'RegexField'
 | 
				
			||||||
| 
						 | 
					@ -534,12 +533,6 @@ class RegexField(CharField):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    regex = property(_get_regex, _set_regex)
 | 
					    regex = property(_get_regex, _set_regex)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __deepcopy__(self, memo):
 | 
					 | 
				
			||||||
        result = copy.copy(self)
 | 
					 | 
				
			||||||
        memo[id(self)] = result
 | 
					 | 
				
			||||||
        result.validators = self.validators[:]
 | 
					 | 
				
			||||||
        return result
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class DateField(WritableField):
 | 
					class DateField(WritableField):
 | 
				
			||||||
    type_name = 'DateField'
 | 
					    type_name = 'DateField'
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -212,7 +212,7 @@ class GenericAPIView(views.APIView):
 | 
				
			||||||
        You may want to override this if you need to provide different
 | 
					        You may want to override this if you need to provide different
 | 
				
			||||||
        serializations depending on the incoming request.
 | 
					        serializations depending on the incoming request.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        (Eg. admins get full serialization, others get basic serilization)
 | 
					        (Eg. admins get full serialization, others get basic serialization)
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        serializer_class = self.serializer_class
 | 
					        serializer_class = self.serializer_class
 | 
				
			||||||
        if serializer_class is not None:
 | 
					        if serializer_class is not None:
 | 
				
			||||||
| 
						 | 
					@ -285,7 +285,7 @@ class GenericAPIView(views.APIView):
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
            filter_kwargs = {self.slug_field: slug}
 | 
					            filter_kwargs = {self.slug_field: slug}
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            raise exceptions.ConfigurationError(
 | 
					            raise ImproperlyConfigured(
 | 
				
			||||||
                'Expected view %s to be called with a URL keyword argument '
 | 
					                'Expected view %s to be called with a URL keyword argument '
 | 
				
			||||||
                'named "%s". Fix your URL conf, or set the `.lookup_field` '
 | 
					                'named "%s". Fix your URL conf, or set the `.lookup_field` '
 | 
				
			||||||
                'attribute on the view correctly.' %
 | 
					                'attribute on the view correctly.' %
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -11,6 +11,7 @@ from __future__ import unicode_literals
 | 
				
			||||||
import copy
 | 
					import copy
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
from django import forms
 | 
					from django import forms
 | 
				
			||||||
 | 
					from django.core.exceptions import ImproperlyConfigured
 | 
				
			||||||
from django.http.multipartparser import parse_header
 | 
					from django.http.multipartparser import parse_header
 | 
				
			||||||
from django.template import RequestContext, loader, Template
 | 
					from django.template import RequestContext, loader, Template
 | 
				
			||||||
from django.utils.xmlutils import SimplerXMLGenerator
 | 
					from django.utils.xmlutils import SimplerXMLGenerator
 | 
				
			||||||
| 
						 | 
					@ -18,7 +19,6 @@ from rest_framework.compat import StringIO
 | 
				
			||||||
from rest_framework.compat import six
 | 
					from rest_framework.compat import six
 | 
				
			||||||
from rest_framework.compat import smart_text
 | 
					from rest_framework.compat import smart_text
 | 
				
			||||||
from rest_framework.compat import yaml
 | 
					from rest_framework.compat import yaml
 | 
				
			||||||
from rest_framework.exceptions import ConfigurationError
 | 
					 | 
				
			||||||
from rest_framework.settings import api_settings
 | 
					from rest_framework.settings import api_settings
 | 
				
			||||||
from rest_framework.request import clone_request
 | 
					from rest_framework.request import clone_request
 | 
				
			||||||
from rest_framework.utils import encoders
 | 
					from rest_framework.utils import encoders
 | 
				
			||||||
| 
						 | 
					@ -270,7 +270,7 @@ class TemplateHTMLRenderer(BaseRenderer):
 | 
				
			||||||
            return [self.template_name]
 | 
					            return [self.template_name]
 | 
				
			||||||
        elif hasattr(view, 'get_template_names'):
 | 
					        elif hasattr(view, 'get_template_names'):
 | 
				
			||||||
            return view.get_template_names()
 | 
					            return view.get_template_names()
 | 
				
			||||||
        raise ConfigurationError('Returned a template response with no template_name')
 | 
					        raise ImproperlyConfigured('Returned a template response with no template_name')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_exception_template(self, response):
 | 
					    def get_exception_template(self, response):
 | 
				
			||||||
        template_names = [name % {'status_code': response.status_code}
 | 
					        template_names = [name % {'status_code': response.status_code}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -215,6 +215,7 @@ class DefaultRouter(SimpleRouter):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    include_root_view = True
 | 
					    include_root_view = True
 | 
				
			||||||
    include_format_suffixes = True
 | 
					    include_format_suffixes = True
 | 
				
			||||||
 | 
					    root_view_name = 'api-root'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_api_root_view(self):
 | 
					    def get_api_root_view(self):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					@ -244,7 +245,7 @@ class DefaultRouter(SimpleRouter):
 | 
				
			||||||
        urls = []
 | 
					        urls = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.include_root_view:
 | 
					        if self.include_root_view:
 | 
				
			||||||
            root_url = url(r'^$', self.get_api_root_view(), name='api-root')
 | 
					            root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
 | 
				
			||||||
            urls.append(root_url)
 | 
					            urls.append(root_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        default_urls = super(DefaultRouter, self).get_urls()
 | 
					        default_urls = super(DefaultRouter, self).get_urls()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -428,6 +428,47 @@ class OAuthTests(TestCase):
 | 
				
			||||||
        response = self.csrf_client.post('/oauth-with-scope/', params)
 | 
					        response = self.csrf_client.post('/oauth-with-scope/', params)
 | 
				
			||||||
        self.assertEqual(response.status_code, 200)
 | 
					        self.assertEqual(response.status_code, 200)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
 | 
				
			||||||
 | 
					    @unittest.skipUnless(oauth, 'oauth2 not installed')
 | 
				
			||||||
 | 
					    def test_bad_consumer_key(self):
 | 
				
			||||||
 | 
					        """Ensure POSTing using HMAC_SHA1 signature method passes"""
 | 
				
			||||||
 | 
					        params = {
 | 
				
			||||||
 | 
					            'oauth_version': "1.0",
 | 
				
			||||||
 | 
					            'oauth_nonce': oauth.generate_nonce(),
 | 
				
			||||||
 | 
					            'oauth_timestamp': int(time.time()),
 | 
				
			||||||
 | 
					            'oauth_token': self.token.key,
 | 
				
			||||||
 | 
					            'oauth_consumer_key': 'badconsumerkey'
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        signature_method = oauth.SignatureMethod_HMAC_SHA1()
 | 
				
			||||||
 | 
					        req.sign_request(signature_method, self.consumer, self.token)
 | 
				
			||||||
 | 
					        auth = req.to_header()["Authorization"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 401)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
 | 
				
			||||||
 | 
					    @unittest.skipUnless(oauth, 'oauth2 not installed')
 | 
				
			||||||
 | 
					    def test_bad_token_key(self):
 | 
				
			||||||
 | 
					        """Ensure POSTing using HMAC_SHA1 signature method passes"""
 | 
				
			||||||
 | 
					        params = {
 | 
				
			||||||
 | 
					            'oauth_version': "1.0",
 | 
				
			||||||
 | 
					            'oauth_nonce': oauth.generate_nonce(),
 | 
				
			||||||
 | 
					            'oauth_timestamp': int(time.time()),
 | 
				
			||||||
 | 
					            'oauth_token': 'badtokenkey',
 | 
				
			||||||
 | 
					            'oauth_consumer_key': self.consumer.key
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        signature_method = oauth.SignatureMethod_HMAC_SHA1()
 | 
				
			||||||
 | 
					        req.sign_request(signature_method, self.consumer, self.token)
 | 
				
			||||||
 | 
					        auth = req.to_header()["Authorization"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
 | 
				
			||||||
 | 
					        self.assertEqual(response.status_code, 401)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class OAuth2Tests(TestCase):
 | 
					class OAuth2Tests(TestCase):
 | 
				
			||||||
    """OAuth 2.0 authentication"""
 | 
					    """OAuth 2.0 authentication"""
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -866,3 +866,33 @@ class FieldCallableDefault(TestCase):
 | 
				
			||||||
        into = {}
 | 
					        into = {}
 | 
				
			||||||
        field.field_from_native({}, {}, 'field', into)
 | 
					        field.field_from_native({}, {}, 'field', into)
 | 
				
			||||||
        self.assertEqual(into, {'field': 'foo bar'})
 | 
					        self.assertEqual(into, {'field': 'foo bar'})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CustomIntegerField(TestCase):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					        Test that custom fields apply min_value and max_value constraints
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def test_custom_fields_can_be_validated_for_value(self):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class MoneyField(models.PositiveIntegerField):
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class EntryModel(models.Model):
 | 
				
			||||||
 | 
					            bank = MoneyField(validators=[validators.MaxValueValidator(100)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        class EntrySerializer(serializers.ModelSerializer):
 | 
				
			||||||
 | 
					            class Meta:
 | 
				
			||||||
 | 
					                model = EntryModel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        entry = EntryModel(bank=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        serializer = EntrySerializer(entry, data={"bank": 11})
 | 
				
			||||||
 | 
					        self.assertTrue(serializer.is_valid())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        serializer = EntrySerializer(entry, data={"bank": -1})
 | 
				
			||||||
 | 
					        self.assertFalse(serializer.is_valid())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        serializer = EntrySerializer(entry, data={"bank": 101})
 | 
				
			||||||
 | 
					        self.assertFalse(serializer.is_valid())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -6,7 +6,7 @@ from rest_framework import serializers, viewsets
 | 
				
			||||||
from rest_framework.compat import include, patterns, url
 | 
					from rest_framework.compat import include, patterns, url
 | 
				
			||||||
from rest_framework.decorators import link, action
 | 
					from rest_framework.decorators import link, action
 | 
				
			||||||
from rest_framework.response import Response
 | 
					from rest_framework.response import Response
 | 
				
			||||||
from rest_framework.routers import SimpleRouter
 | 
					from rest_framework.routers import SimpleRouter, DefaultRouter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
factory = RequestFactory()
 | 
					factory = RequestFactory()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -148,3 +148,17 @@ class TestTrailingSlash(TestCase):
 | 
				
			||||||
        expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
 | 
					        expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
 | 
				
			||||||
        for idx in range(len(expected)):
 | 
					        for idx in range(len(expected)):
 | 
				
			||||||
            self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
 | 
					            self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestNameableRoot(TestCase):
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        class NoteViewSet(viewsets.ModelViewSet):
 | 
				
			||||||
 | 
					            model = RouterTestModel
 | 
				
			||||||
 | 
					        self.router = DefaultRouter()
 | 
				
			||||||
 | 
					        self.router.root_view_name = 'nameable-root'
 | 
				
			||||||
 | 
					        self.router.register(r'notes', NoteViewSet)
 | 
				
			||||||
 | 
					        self.urls = self.router.urls
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_router_has_custom_name(self):
 | 
				
			||||||
 | 
					        expected = 'nameable-root'
 | 
				
			||||||
 | 
					        self.assertEqual(expected, self.urls[0].name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -3,7 +3,7 @@ Provides various throttling policies.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
from __future__ import unicode_literals
 | 
					from __future__ import unicode_literals
 | 
				
			||||||
from django.core.cache import cache
 | 
					from django.core.cache import cache
 | 
				
			||||||
from rest_framework import exceptions
 | 
					from django.core.exceptions import ImproperlyConfigured
 | 
				
			||||||
from rest_framework.settings import api_settings
 | 
					from rest_framework.settings import api_settings
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle):
 | 
				
			||||||
        if not getattr(self, 'scope', None):
 | 
					        if not getattr(self, 'scope', None):
 | 
				
			||||||
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
 | 
					            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
 | 
				
			||||||
                   self.__class__.__name__)
 | 
					                   self.__class__.__name__)
 | 
				
			||||||
            raise exceptions.ConfigurationError(msg)
 | 
					            raise ImproperlyConfigured(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
 | 
					            return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
 | 
				
			||||||
        except KeyError:
 | 
					        except KeyError:
 | 
				
			||||||
            msg = "No default throttle rate set for '%s' scope" % self.scope
 | 
					            msg = "No default throttle rate set for '%s' scope" % self.scope
 | 
				
			||||||
            raise exceptions.ConfigurationError(msg)
 | 
					            raise ImproperlyConfigured(msg)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def parse_rate(self, rate):
 | 
					    def parse_rate(self, rate):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -304,10 +304,10 @@ class APIView(View):
 | 
				
			||||||
        `.dispatch()` is pretty much the same as Django's regular dispatch,
 | 
					        `.dispatch()` is pretty much the same as Django's regular dispatch,
 | 
				
			||||||
        but with extra hooks for startup, finalize, and exception handling.
 | 
					        but with extra hooks for startup, finalize, and exception handling.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        request = self.initialize_request(request, *args, **kwargs)
 | 
					 | 
				
			||||||
        self.request = request
 | 
					 | 
				
			||||||
        self.args = args
 | 
					        self.args = args
 | 
				
			||||||
        self.kwargs = kwargs
 | 
					        self.kwargs = kwargs
 | 
				
			||||||
 | 
					        request = self.initialize_request(request, *args, **kwargs)
 | 
				
			||||||
 | 
					        self.request = request
 | 
				
			||||||
        self.headers = self.default_response_headers  # deprecate?
 | 
					        self.headers = self.default_response_headers  # deprecate?
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
| 
						 | 
					@ -341,8 +341,15 @@ class APIView(View):
 | 
				
			||||||
        Return a dictionary of metadata about the view.
 | 
					        Return a dictionary of metadata about the view.
 | 
				
			||||||
        Used to return responses for OPTIONS requests.
 | 
					        Used to return responses for OPTIONS requests.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # This is used by ViewSets to disambiguate instance vs list views
 | 
				
			||||||
 | 
					        view_name_suffix = getattr(self, 'suffix', None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # By default we can't provide any form-like information, however the
 | 
				
			||||||
 | 
					        # generic views override this implementation and add additional
 | 
				
			||||||
 | 
					        # information for POST and PUT methods, based on the serializer.
 | 
				
			||||||
        ret = SortedDict()
 | 
					        ret = SortedDict()
 | 
				
			||||||
        ret['name'] = get_view_name(self.__class__)
 | 
					        ret['name'] = get_view_name(self.__class__, view_name_suffix)
 | 
				
			||||||
        ret['description'] = get_view_description(self.__class__)
 | 
					        ret['description'] = get_view_description(self.__class__)
 | 
				
			||||||
        ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
 | 
					        ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
 | 
				
			||||||
        ret['parses'] = [parser.media_type for parser in self.parser_classes]
 | 
					        ret['parses'] = [parser.media_type for parser in self.parser_classes]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user