Merge remote-tracking branch 'upstream/master' into writable-nested-modelserializer

This commit is contained in:
Ethan Fremen 2013-06-10 10:29:25 -07:00
commit 0e75bcd259
15 changed files with 107 additions and 60 deletions

View File

@ -60,7 +60,7 @@ The following attributes control the basic view behavior.
* `queryset` - The queryset that should be used for returning objects from this view. Typically, you must either set this attribute, or override the `get_queryset()` method. * `queryset` - The queryset that should be used for returning objects from this view. Typically, you must either set this attribute, or override the `get_queryset()` method.
* `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output. Typically, you must either set this attribute, or override the `get_serializer_class()` method. * `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output. Typically, you must either set this attribute, or override the `get_serializer_class()` method.
* `lookup_field` - The field that should be used to lookup individual model instances. Defaults to `'pk'`. The URL conf should include a keyword argument corresponding to this value. More complex lookup styles can be supported by overriding the `get_object()` method. * `lookup_field` - The field that should be used to lookup individual model instances. Defaults to `'pk'`. The URL conf should include a keyword argument corresponding to this value. More complex lookup styles can be supported by overriding the `get_object()` method. Note that when using hyperlinked APIs you'll need to ensure that *both* the API views *and* the serializer classes use lookup fields that correctly correspond with the URL conf.
**Shortcuts**: **Shortcuts**:
@ -131,7 +131,7 @@ You may want to override this method to provide more complex behavior such as mo
For example: For example:
def get_paginate_by(self): def get_paginate_by(self):
self.request.accepted_renderer.format == 'html': if self.request.accepted_renderer.format == 'html':
return 20 return 20
return 100 return 100

View File

@ -39,7 +39,7 @@ Declaring a serializer looks very similar to declaring a form:
an existing model instance, or create a new model instance. an existing model instance, or create a new model instance.
""" """
if instance is not None: if instance is not None:
instance.title = attrs.get('title', instance.title) instance.email = attrs.get('email', instance.email)
instance.content = attrs.get('content', instance.content) instance.content = attrs.get('content', instance.content)
instance.created = attrs.get('created', instance.created) instance.created = attrs.get('created', instance.created)
return instance return instance
@ -387,7 +387,7 @@ There needs to be a way of determining which views should be used for hyperlinki
By default hyperlinks are expected to correspond to a view name that matches the style `'{model_name}-detail'`, and looks up the instance by a `pk` keyword argument. By default hyperlinks are expected to correspond to a view name that matches the style `'{model_name}-detail'`, and looks up the instance by a `pk` keyword argument.
You can change the field that is used for object lookups by setting the `lookup_field` option. The value of this option should correspond both with a kwarg in the URL conf, and with an field on the model. For example: You can change the field that is used for object lookups by setting the `lookup_field` option. The value of this option should correspond both with a kwarg in the URL conf, and with a field on the model. For example:
class AccountSerializer(serializers.HyperlinkedModelSerializer): class AccountSerializer(serializers.HyperlinkedModelSerializer):
class Meta: class Meta:
@ -395,6 +395,8 @@ You can change the field that is used for object lookups by setting the `lookup_
fields = ('url', 'account_name', 'users', 'created') fields = ('url', 'account_name', 'users', 'created')
lookup_field = 'slug' lookup_field = 'slug'
Not that the `lookup_field` will be used as the default on *all* hyperlinked fields, including both the URL identity, and any hyperlinked relationships.
For more specfic requirements such as specifying a different lookup for each field, you'll want to set the fields on the serializer explicitly. For example: For more specfic requirements such as specifying a different lookup for each field, you'll want to set the fields on the serializer explicitly. For example:
class AccountSerializer(serializers.HyperlinkedModelSerializer): class AccountSerializer(serializers.HyperlinkedModelSerializer):

View File

@ -209,8 +209,6 @@ To create a base viewset class that provides `create`, `list` and `retrieve` ope
mixins.ListMixin, mixins.ListMixin,
mixins.RetrieveMixin, mixins.RetrieveMixin,
viewsets.GenericViewSet): viewsets.GenericViewSet):
pass
""" """
A viewset that provides `retrieve`, `update`, and `list` actions. A viewset that provides `retrieve`, `update`, and `list` actions.

View File

@ -139,6 +139,8 @@ The following people have helped make REST framework great.
* Pascal Borreli - [pborreli] * Pascal Borreli - [pborreli]
* Alex Burgel - [aburgel] * Alex Burgel - [aburgel]
* David Medina - [copitux] * David Medina - [copitux]
* Areski Belaid - [areski]
* Ethan Freman - [mindlace]
Many thanks to everyone who's contributed to the project. Many thanks to everyone who's contributed to the project.
@ -314,3 +316,5 @@ You can also contact [@_tomchristie][twitter] directly on twitter.
[pborreli]: https://github.com/pborreli [pborreli]: https://github.com/pborreli
[aburgel]: https://github.com/aburgel [aburgel]: https://github.com/aburgel
[copitux]: https://github.com/copitux [copitux]: https://github.com/copitux
[areski]: https://github.com/areski
[mindlace]: https://github.com/mindlace

View File

@ -10,7 +10,9 @@ A `ViewSet` class is only bound to a set of method handlers at the last moment,
Let's take our current set of views, and refactor them into view sets. Let's take our current set of views, and refactor them into view sets.
First of all let's refactor our `UserListView` and `UserDetailView` views into a single `UserViewSet`. We can remove the two views, and replace then with a single class: First of all let's refactor our `UserList` and `UserDetail` views into a single `UserViewSet`. We can remove the two views, and replace then with a single class:
from rest_framework import viewsets
class UserViewSet(viewsets.ReadOnlyModelViewSet): class UserViewSet(viewsets.ReadOnlyModelViewSet):
""" """
@ -23,7 +25,6 @@ Here we've used `ReadOnlyModelViewSet` class to automatically provide the defaul
Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class. Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class.
from rest_framework import viewsets
from rest_framework.decorators import link from rest_framework.decorators import link
class SnippetViewSet(viewsets.ModelViewSet): class SnippetViewSet(viewsets.ModelViewSet):
@ -73,7 +74,7 @@ In the `urls.py` file we bind our `ViewSet` classes into a set of concrete views
}) })
snippet_highlight = SnippetViewSet.as_view({ snippet_highlight = SnippetViewSet.as_view({
'get': 'highlight' 'get': 'highlight'
}) }, renderer_classes=[renderers.StaticHTMLRenderer])
user_list = UserViewSet.as_view({ user_list = UserViewSet.as_view({
'get': 'list' 'get': 'list'
}) })

View File

@ -230,8 +230,9 @@ class OAuthAuthentication(BaseAuthentication):
try: try:
consumer_key = oauth_request.get_parameter('oauth_consumer_key') consumer_key = oauth_request.get_parameter('oauth_consumer_key')
consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
except oauth_provider.store.InvalidConsumerError as err: except oauth_provider.store.InvalidConsumerError:
raise exceptions.AuthenticationFailed(err) msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key')
raise exceptions.AuthenticationFailed(msg)
if consumer.status != oauth_provider.consts.ACCEPTED: if consumer.status != oauth_provider.consts.ACCEPTED:
msg = 'Invalid consumer key status: %s' % consumer.get_status_display() msg = 'Invalid consumer key status: %s' % consumer.get_status_display()

View File

@ -86,10 +86,3 @@ class Throttled(APIException):
self.detail = format % (self.wait, self.wait != 1 and 's' or '') self.detail = format % (self.wait, self.wait != 1 and 's' or '')
else: else:
self.detail = detail or self.default_detail self.detail = detail or self.default_detail
class ConfigurationError(Exception):
"""
Indicates an internal server error.
"""
pass

View File

@ -7,25 +7,24 @@ from __future__ import unicode_literals
import copy import copy
import datetime import datetime
from decimal import Decimal, DecimalException
import inspect import inspect
import re import re
import warnings import warnings
from decimal import Decimal, DecimalException
from django import forms
from django.core import validators from django.core import validators
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.conf import settings from django.conf import settings
from django.db.models.fields import BLANK_CHOICE_DASH from django.db.models.fields import BLANK_CHOICE_DASH
from django import forms
from django.forms import widgets from django.forms import widgets
from django.utils.encoding import is_protected_type from django.utils.encoding import is_protected_type
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from django.utils.datastructures import SortedDict from django.utils.datastructures import SortedDict
from rest_framework import ISO_8601 from rest_framework import ISO_8601
from rest_framework.compat import (timezone, parse_date, parse_datetime, from rest_framework.compat import (
parse_time) timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text,
from rest_framework.compat import BytesIO force_text, is_non_str_iterable
from rest_framework.compat import six )
from rest_framework.compat import smart_text, force_text, is_non_str_iterable
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -256,6 +255,12 @@ class WritableField(Field):
widget = widget() widget = widget()
self.widget = widget self.widget = widget
def __deepcopy__(self, memo):
result = copy.copy(self)
memo[id(self)] = result
result.validators = self.validators[:]
return result
def validate(self, value): def validate(self, value):
if value in validators.EMPTY_VALUES and self.required: if value in validators.EMPTY_VALUES and self.required:
raise ValidationError(self.error_messages['required']) raise ValidationError(self.error_messages['required'])
@ -428,13 +433,6 @@ class SlugField(CharField):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(SlugField, self).__init__(*args, **kwargs) super(SlugField, self).__init__(*args, **kwargs)
def __deepcopy__(self, memo):
result = copy.copy(self)
memo[id(self)] = result
#result.widget = copy.deepcopy(self.widget, memo)
result.validators = self.validators[:]
return result
class ChoiceField(WritableField): class ChoiceField(WritableField):
type_name = 'ChoiceField' type_name = 'ChoiceField'
@ -503,13 +501,6 @@ class EmailField(CharField):
return None return None
return ret.strip() return ret.strip()
def __deepcopy__(self, memo):
result = copy.copy(self)
memo[id(self)] = result
#result.widget = copy.deepcopy(self.widget, memo)
result.validators = self.validators[:]
return result
class RegexField(CharField): class RegexField(CharField):
type_name = 'RegexField' type_name = 'RegexField'
@ -534,12 +525,6 @@ class RegexField(CharField):
regex = property(_get_regex, _set_regex) regex = property(_get_regex, _set_regex)
def __deepcopy__(self, memo):
result = copy.copy(self)
memo[id(self)] = result
result.validators = self.validators[:]
return result
class DateField(WritableField): class DateField(WritableField):
type_name = 'DateField' type_name = 'DateField'

View File

@ -285,7 +285,7 @@ class GenericAPIView(views.APIView):
) )
filter_kwargs = {self.slug_field: slug} filter_kwargs = {self.slug_field: slug}
else: else:
raise exceptions.ConfigurationError( raise ImproperlyConfigured(
'Expected view %s to be called with a URL keyword argument ' 'Expected view %s to be called with a URL keyword argument '
'named "%s". Fix your URL conf, or set the `.lookup_field` ' 'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' % 'attribute on the view correctly.' %

View File

@ -11,6 +11,7 @@ from __future__ import unicode_literals
import copy import copy
import json import json
from django import forms from django import forms
from django.core.exceptions import ImproperlyConfigured
from django.http.multipartparser import parse_header from django.http.multipartparser import parse_header
from django.template import RequestContext, loader, Template from django.template import RequestContext, loader, Template
from django.utils.xmlutils import SimplerXMLGenerator from django.utils.xmlutils import SimplerXMLGenerator
@ -18,7 +19,6 @@ from rest_framework.compat import StringIO
from rest_framework.compat import six from rest_framework.compat import six
from rest_framework.compat import smart_text from rest_framework.compat import smart_text
from rest_framework.compat import yaml from rest_framework.compat import yaml
from rest_framework.exceptions import ConfigurationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.request import clone_request from rest_framework.request import clone_request
from rest_framework.utils import encoders from rest_framework.utils import encoders
@ -270,7 +270,7 @@ class TemplateHTMLRenderer(BaseRenderer):
return [self.template_name] return [self.template_name]
elif hasattr(view, 'get_template_names'): elif hasattr(view, 'get_template_names'):
return view.get_template_names() return view.get_template_names()
raise ConfigurationError('Returned a template response with no template_name') raise ImproperlyConfigured('Returned a template response with no template_name')
def get_exception_template(self, response): def get_exception_template(self, response):
template_names = [name % {'status_code': response.status_code} template_names = [name % {'status_code': response.status_code}

View File

@ -215,6 +215,7 @@ class DefaultRouter(SimpleRouter):
""" """
include_root_view = True include_root_view = True
include_format_suffixes = True include_format_suffixes = True
root_view_name = 'api-root'
def get_api_root_view(self): def get_api_root_view(self):
""" """
@ -244,7 +245,7 @@ class DefaultRouter(SimpleRouter):
urls = [] urls = []
if self.include_root_view: if self.include_root_view:
root_url = url(r'^$', self.get_api_root_view(), name='api-root') root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
urls.append(root_url) urls.append(root_url)
default_urls = super(DefaultRouter, self).get_urls() default_urls = super(DefaultRouter, self).get_urls()

View File

@ -428,6 +428,47 @@ class OAuthTests(TestCase):
response = self.csrf_client.post('/oauth-with-scope/', params) response = self.csrf_client.post('/oauth-with-scope/', params)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
@unittest.skipUnless(oauth, 'oauth2 not installed')
def test_bad_consumer_key(self):
"""Ensure POSTing using HMAC_SHA1 signature method passes"""
params = {
'oauth_version': "1.0",
'oauth_nonce': oauth.generate_nonce(),
'oauth_timestamp': int(time.time()),
'oauth_token': self.token.key,
'oauth_consumer_key': 'badconsumerkey'
}
req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
signature_method = oauth.SignatureMethod_HMAC_SHA1()
req.sign_request(signature_method, self.consumer, self.token)
auth = req.to_header()["Authorization"]
response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
@unittest.skipUnless(oauth, 'oauth2 not installed')
def test_bad_token_key(self):
"""Ensure POSTing using HMAC_SHA1 signature method passes"""
params = {
'oauth_version': "1.0",
'oauth_nonce': oauth.generate_nonce(),
'oauth_timestamp': int(time.time()),
'oauth_token': 'badtokenkey',
'oauth_consumer_key': self.consumer.key
}
req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
signature_method = oauth.SignatureMethod_HMAC_SHA1()
req.sign_request(signature_method, self.consumer, self.token)
auth = req.to_header()["Authorization"]
response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
class OAuth2Tests(TestCase): class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication""" """OAuth 2.0 authentication"""

View File

@ -6,7 +6,7 @@ from rest_framework import serializers, viewsets
from rest_framework.compat import include, patterns, url from rest_framework.compat import include, patterns, url
from rest_framework.decorators import link, action from rest_framework.decorators import link, action
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import SimpleRouter from rest_framework.routers import SimpleRouter, DefaultRouter
factory = RequestFactory() factory = RequestFactory()
@ -148,3 +148,17 @@ class TestTrailingSlash(TestCase):
expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
for idx in range(len(expected)): for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern) self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
class TestNameableRoot(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
model = RouterTestModel
self.router = DefaultRouter()
self.router.root_view_name = 'nameable-root'
self.router.register(r'notes', NoteViewSet)
self.urls = self.router.urls
def test_router_has_custom_name(self):
expected = 'nameable-root'
self.assertEqual(expected, self.urls[0].name)

View File

@ -3,7 +3,7 @@ Provides various throttling policies.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.cache import cache from django.core.cache import cache
from rest_framework import exceptions from django.core.exceptions import ImproperlyConfigured
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
import time import time
@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle):
if not getattr(self, 'scope', None): if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__) self.__class__.__name__)
raise exceptions.ConfigurationError(msg) raise ImproperlyConfigured(msg)
try: try:
return self.settings.DEFAULT_THROTTLE_RATES[self.scope] return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
except KeyError: except KeyError:
msg = "No default throttle rate set for '%s' scope" % self.scope msg = "No default throttle rate set for '%s' scope" % self.scope
raise exceptions.ConfigurationError(msg) raise ImproperlyConfigured(msg)
def parse_rate(self, rate): def parse_rate(self, rate):
""" """

View File

@ -304,10 +304,10 @@ class APIView(View):
`.dispatch()` is pretty much the same as Django's regular dispatch, `.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling. but with extra hooks for startup, finalize, and exception handling.
""" """
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.headers = self.default_response_headers # deprecate? self.headers = self.default_response_headers # deprecate?
try: try:
@ -341,8 +341,15 @@ class APIView(View):
Return a dictionary of metadata about the view. Return a dictionary of metadata about the view.
Used to return responses for OPTIONS requests. Used to return responses for OPTIONS requests.
""" """
# This is used by ViewSets to disambiguate instance vs list views
view_name_suffix = getattr(self, 'suffix', None)
# By default we can't provide any form-like information, however the
# generic views override this implementation and add additional
# information for POST and PUT methods, based on the serializer.
ret = SortedDict() ret = SortedDict()
ret['name'] = get_view_name(self.__class__) ret['name'] = get_view_name(self.__class__, view_name_suffix)
ret['description'] = get_view_description(self.__class__) ret['description'] = get_view_description(self.__class__)
ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] ret['renders'] = [renderer.media_type for renderer in self.renderer_classes]
ret['parses'] = [parser.media_type for parser in self.parser_classes] ret['parses'] = [parser.media_type for parser in self.parser_classes]