diff --git a/docs/api-guide/generic-views.md b/docs/api-guide/generic-views.md index 20b9440ba..cd1bc7a1c 100755 --- a/docs/api-guide/generic-views.md +++ b/docs/api-guide/generic-views.md @@ -60,7 +60,7 @@ The following attributes control the basic view behavior. * `queryset` - The queryset that should be used for returning objects from this view. Typically, you must either set this attribute, or override the `get_queryset()` method. * `serializer_class` - The serializer class that should be used for validating and deserializing input, and for serializing output. Typically, you must either set this attribute, or override the `get_serializer_class()` method. -* `lookup_field` - The field that should be used to lookup individual model instances. Defaults to `'pk'`. The URL conf should include a keyword argument corresponding to this value. More complex lookup styles can be supported by overriding the `get_object()` method. +* `lookup_field` - The field that should be used to lookup individual model instances. Defaults to `'pk'`. The URL conf should include a keyword argument corresponding to this value. More complex lookup styles can be supported by overriding the `get_object()` method. Note that when using hyperlinked APIs you'll need to ensure that *both* the API views *and* the serializer classes use lookup fields that correctly correspond with the URL conf. **Shortcuts**: @@ -131,7 +131,7 @@ You may want to override this method to provide more complex behavior such as mo For example: def get_paginate_by(self): - self.request.accepted_renderer.format == 'html': + if self.request.accepted_renderer.format == 'html': return 20 return 100 diff --git a/docs/api-guide/serializers.md b/docs/api-guide/serializers.md index 44ee7e39b..4c02d5301 100644 --- a/docs/api-guide/serializers.md +++ b/docs/api-guide/serializers.md @@ -39,7 +39,7 @@ Declaring a serializer looks very similar to declaring a form: an existing model instance, or create a new model instance. """ if instance is not None: - instance.title = attrs.get('title', instance.title) + instance.email = attrs.get('email', instance.email) instance.content = attrs.get('content', instance.content) instance.created = attrs.get('created', instance.created) return instance @@ -387,7 +387,7 @@ There needs to be a way of determining which views should be used for hyperlinki By default hyperlinks are expected to correspond to a view name that matches the style `'{model_name}-detail'`, and looks up the instance by a `pk` keyword argument. -You can change the field that is used for object lookups by setting the `lookup_field` option. The value of this option should correspond both with a kwarg in the URL conf, and with an field on the model. For example: +You can change the field that is used for object lookups by setting the `lookup_field` option. The value of this option should correspond both with a kwarg in the URL conf, and with a field on the model. For example: class AccountSerializer(serializers.HyperlinkedModelSerializer): class Meta: @@ -395,6 +395,8 @@ You can change the field that is used for object lookups by setting the `lookup_ fields = ('url', 'account_name', 'users', 'created') lookup_field = 'slug' +Not that the `lookup_field` will be used as the default on *all* hyperlinked fields, including both the URL identity, and any hyperlinked relationships. + For more specfic requirements such as specifying a different lookup for each field, you'll want to set the fields on the serializer explicitly. For example: class AccountSerializer(serializers.HyperlinkedModelSerializer): diff --git a/docs/api-guide/viewsets.md b/docs/api-guide/viewsets.md index 2783da98f..79257e2af 100644 --- a/docs/api-guide/viewsets.md +++ b/docs/api-guide/viewsets.md @@ -209,8 +209,6 @@ To create a base viewset class that provides `create`, `list` and `retrieve` ope mixins.ListMixin, mixins.RetrieveMixin, viewsets.GenericViewSet): - pass - """ A viewset that provides `retrieve`, `update`, and `list` actions. diff --git a/docs/topics/credits.md b/docs/topics/credits.md index bbe209c74..3f0ee429c 100644 --- a/docs/topics/credits.md +++ b/docs/topics/credits.md @@ -139,6 +139,9 @@ The following people have helped make REST framework great. * Pascal Borreli - [pborreli] * Alex Burgel - [aburgel] * David Medina - [copitux] +* Areski Belaid - [areski] +* Ethan Freman - [mindlace] +* David Sanders - [davesque] Many thanks to everyone who's contributed to the project. @@ -314,3 +317,7 @@ You can also contact [@_tomchristie][twitter] directly on twitter. [pborreli]: https://github.com/pborreli [aburgel]: https://github.com/aburgel [copitux]: https://github.com/copitux +[areski]: https://github.com/areski +[mindlace]: https://github.com/mindlace +[davesque]: https://github.com/davesque + diff --git a/docs/tutorial/6-viewsets-and-routers.md b/docs/tutorial/6-viewsets-and-routers.md index 4ed10e828..f16add39d 100644 --- a/docs/tutorial/6-viewsets-and-routers.md +++ b/docs/tutorial/6-viewsets-and-routers.md @@ -10,7 +10,9 @@ A `ViewSet` class is only bound to a set of method handlers at the last moment, Let's take our current set of views, and refactor them into view sets. -First of all let's refactor our `UserListView` and `UserDetailView` views into a single `UserViewSet`. We can remove the two views, and replace then with a single class: +First of all let's refactor our `UserList` and `UserDetail` views into a single `UserViewSet`. We can remove the two views, and replace then with a single class: + + from rest_framework import viewsets class UserViewSet(viewsets.ReadOnlyModelViewSet): """ @@ -23,15 +25,14 @@ Here we've used `ReadOnlyModelViewSet` class to automatically provide the defaul Next we're going to replace the `SnippetList`, `SnippetDetail` and `SnippetHighlight` view classes. We can remove the three views, and again replace them with a single class. - from rest_framework import viewsets from rest_framework.decorators import link class SnippetViewSet(viewsets.ModelViewSet): """ This viewset automatically provides `list`, `create`, `retrieve`, `update` and `destroy` actions. - - Additionally we also provide an extra `highlight` action. + + Additionally we also provide an extra `highlight` action. """ queryset = Snippet.objects.all() serializer_class = SnippetSerializer @@ -73,7 +74,7 @@ In the `urls.py` file we bind our `ViewSet` classes into a set of concrete views }) snippet_highlight = SnippetViewSet.as_view({ 'get': 'highlight' - }) + }, renderer_classes=[renderers.StaticHTMLRenderer]) user_list = UserViewSet.as_view({ 'get': 'list' }) @@ -107,7 +108,7 @@ Here's our re-wired `urls.py` file. router = DefaultRouter() router.register(r'snippets', views.SnippetViewSet) router.register(r'users', views.UserViewSet) - + # The API URLs are now determined automatically by the router. # Additionally, we include the login URLs for the browseable API. urlpatterns = patterns('', @@ -131,7 +132,7 @@ With an incredibly small amount of code, we've now got a complete pastebin Web A We've walked through each step of the design process, and seen how if we need to customize anything we can gradually work our way down to simply using regular Django views. -You can review the final [tutorial code][repo] on GitHub, or try out a live example in [the sandbox][sandbox]. +You can review the final [tutorial code][repo] on GitHub, or try out a live example in [the sandbox][sandbox]. ## Onwards and upwards diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 9caca7889..f659a172e 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -230,8 +230,9 @@ class OAuthAuthentication(BaseAuthentication): try: consumer_key = oauth_request.get_parameter('oauth_consumer_key') consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) - except oauth_provider.store.InvalidConsumerError as err: - raise exceptions.AuthenticationFailed(err) + except oauth_provider.store.InvalidConsumerError: + msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key') + raise exceptions.AuthenticationFailed(msg) if consumer.status != oauth_provider.consts.ACCEPTED: msg = 'Invalid consumer key status: %s' % consumer.get_status_display() diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 0c96ecdd5..425a72149 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -86,10 +86,3 @@ class Throttled(APIException): self.detail = format % (self.wait, self.wait != 1 and 's' or '') else: self.detail = detail or self.default_detail - - -class ConfigurationError(Exception): - """ - Indicates an internal server error. - """ - pass diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 535aa2ac8..35848b4ce 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -7,25 +7,24 @@ from __future__ import unicode_literals import copy import datetime -from decimal import Decimal, DecimalException import inspect import re import warnings +from decimal import Decimal, DecimalException +from django import forms from django.core import validators from django.core.exceptions import ValidationError from django.conf import settings from django.db.models.fields import BLANK_CHOICE_DASH -from django import forms from django.forms import widgets from django.utils.encoding import is_protected_type from django.utils.translation import ugettext_lazy as _ from django.utils.datastructures import SortedDict from rest_framework import ISO_8601 -from rest_framework.compat import (timezone, parse_date, parse_datetime, - parse_time) -from rest_framework.compat import BytesIO -from rest_framework.compat import six -from rest_framework.compat import smart_text, force_text, is_non_str_iterable +from rest_framework.compat import ( + timezone, parse_date, parse_datetime, parse_time, BytesIO, six, smart_text, + force_text, is_non_str_iterable +) from rest_framework.settings import api_settings @@ -256,6 +255,12 @@ class WritableField(Field): widget = widget() self.widget = widget + def __deepcopy__(self, memo): + result = copy.copy(self) + memo[id(self)] = result + result.validators = self.validators[:] + return result + def validate(self, value): if value in validators.EMPTY_VALUES and self.required: raise ValidationError(self.error_messages['required']) @@ -331,9 +336,13 @@ class ModelField(WritableField): raise ValueError("ModelField requires 'model_field' kwarg") self.min_length = kwargs.pop('min_length', - getattr(self.model_field, 'min_length', None)) + getattr(self.model_field, 'min_length', None)) self.max_length = kwargs.pop('max_length', - getattr(self.model_field, 'max_length', None)) + getattr(self.model_field, 'max_length', None)) + self.min_value = kwargs.pop('min_value', + getattr(self.model_field, 'min_value', None)) + self.max_value = kwargs.pop('max_value', + getattr(self.model_field, 'max_value', None)) super(ModelField, self).__init__(*args, **kwargs) @@ -341,6 +350,10 @@ class ModelField(WritableField): self.validators.append(validators.MinLengthValidator(self.min_length)) if self.max_length is not None: self.validators.append(validators.MaxLengthValidator(self.max_length)) + if self.min_value is not None: + self.validators.append(validators.MinValueValidator(self.min_value)) + if self.max_value is not None: + self.validators.append(validators.MaxValueValidator(self.max_value)) def from_native(self, value): rel = getattr(self.model_field, "rel", None) @@ -428,13 +441,6 @@ class SlugField(CharField): def __init__(self, *args, **kwargs): super(SlugField, self).__init__(*args, **kwargs) - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - #result.widget = copy.deepcopy(self.widget, memo) - result.validators = self.validators[:] - return result - class ChoiceField(WritableField): type_name = 'ChoiceField' @@ -503,13 +509,6 @@ class EmailField(CharField): return None return ret.strip() - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - #result.widget = copy.deepcopy(self.widget, memo) - result.validators = self.validators[:] - return result - class RegexField(CharField): type_name = 'RegexField' @@ -534,12 +533,6 @@ class RegexField(CharField): regex = property(_get_regex, _set_regex) - def __deepcopy__(self, memo): - result = copy.copy(self) - memo[id(self)] = result - result.validators = self.validators[:] - return result - class DateField(WritableField): type_name = 'DateField' diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 9ccc78980..99e9782e2 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -212,7 +212,7 @@ class GenericAPIView(views.APIView): You may want to override this if you need to provide different serializations depending on the incoming request. - (Eg. admins get full serialization, others get basic serilization) + (Eg. admins get full serialization, others get basic serialization) """ serializer_class = self.serializer_class if serializer_class is not None: @@ -285,7 +285,7 @@ class GenericAPIView(views.APIView): ) filter_kwargs = {self.slug_field: slug} else: - raise exceptions.ConfigurationError( + raise ImproperlyConfigured( 'Expected view %s to be called with a URL keyword argument ' 'named "%s". Fix your URL conf, or set the `.lookup_field` ' 'attribute on the view correctly.' % diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index b2fe43eac..8b2428ad8 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -11,6 +11,7 @@ from __future__ import unicode_literals import copy import json from django import forms +from django.core.exceptions import ImproperlyConfigured from django.http.multipartparser import parse_header from django.template import RequestContext, loader, Template from django.utils.xmlutils import SimplerXMLGenerator @@ -18,7 +19,6 @@ from rest_framework.compat import StringIO from rest_framework.compat import six from rest_framework.compat import smart_text from rest_framework.compat import yaml -from rest_framework.exceptions import ConfigurationError from rest_framework.settings import api_settings from rest_framework.request import clone_request from rest_framework.utils import encoders @@ -270,7 +270,7 @@ class TemplateHTMLRenderer(BaseRenderer): return [self.template_name] elif hasattr(view, 'get_template_names'): return view.get_template_names() - raise ConfigurationError('Returned a template response with no template_name') + raise ImproperlyConfigured('Returned a template response with no template_name') def get_exception_template(self, response): template_names = [name % {'status_code': response.status_code} diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 9764e5692..f70c2cdb1 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -215,6 +215,7 @@ class DefaultRouter(SimpleRouter): """ include_root_view = True include_format_suffixes = True + root_view_name = 'api-root' def get_api_root_view(self): """ @@ -244,7 +245,7 @@ class DefaultRouter(SimpleRouter): urls = [] if self.include_root_view: - root_url = url(r'^$', self.get_api_root_view(), name='api-root') + root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name) urls.append(root_url) default_urls = super(DefaultRouter, self).get_urls() diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index d46ac0798..6a50be064 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -428,6 +428,47 @@ class OAuthTests(TestCase): response = self.csrf_client.post('/oauth-with-scope/', params) self.assertEqual(response.status_code, 200) + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_bad_consumer_key(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': self.token.key, + 'oauth_consumer_key': 'badconsumerkey' + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) + + @unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed') + @unittest.skipUnless(oauth, 'oauth2 not installed') + def test_bad_token_key(self): + """Ensure POSTing using HMAC_SHA1 signature method passes""" + params = { + 'oauth_version': "1.0", + 'oauth_nonce': oauth.generate_nonce(), + 'oauth_timestamp': int(time.time()), + 'oauth_token': 'badtokenkey', + 'oauth_consumer_key': self.consumer.key + } + + req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params) + + signature_method = oauth.SignatureMethod_HMAC_SHA1() + req.sign_request(signature_method, self.consumer, self.token) + auth = req.to_header()["Authorization"] + + response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 401) class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" diff --git a/rest_framework/tests/test_fields.py b/rest_framework/tests/test_fields.py index 69a0468e0..6836ec86f 100644 --- a/rest_framework/tests/test_fields.py +++ b/rest_framework/tests/test_fields.py @@ -866,3 +866,33 @@ class FieldCallableDefault(TestCase): into = {} field.field_from_native({}, {}, 'field', into) self.assertEqual(into, {'field': 'foo bar'}) + + +class CustomIntegerField(TestCase): + """ + Test that custom fields apply min_value and max_value constraints + """ + def test_custom_fields_can_be_validated_for_value(self): + + class MoneyField(models.PositiveIntegerField): + pass + + class EntryModel(models.Model): + bank = MoneyField(validators=[validators.MaxValueValidator(100)]) + + class EntrySerializer(serializers.ModelSerializer): + class Meta: + model = EntryModel + + entry = EntryModel(bank=1) + + serializer = EntrySerializer(entry, data={"bank": 11}) + self.assertTrue(serializer.is_valid()) + + serializer = EntrySerializer(entry, data={"bank": -1}) + self.assertFalse(serializer.is_valid()) + + serializer = EntrySerializer(entry, data={"bank": 101}) + self.assertFalse(serializer.is_valid()) + + diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index a7534f70b..291142cf9 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -6,7 +6,7 @@ from rest_framework import serializers, viewsets from rest_framework.compat import include, patterns, url from rest_framework.decorators import link, action from rest_framework.response import Response -from rest_framework.routers import SimpleRouter +from rest_framework.routers import SimpleRouter, DefaultRouter factory = RequestFactory() @@ -148,3 +148,17 @@ class TestTrailingSlash(TestCase): expected = ['^notes$', '^notes/(?P[^/]+)$'] for idx in range(len(expected)): self.assertEqual(expected[idx], self.urls[idx].regex.pattern) + +class TestNameableRoot(TestCase): + def setUp(self): + class NoteViewSet(viewsets.ModelViewSet): + model = RouterTestModel + self.router = DefaultRouter() + self.router.root_view_name = 'nameable-root' + self.router.register(r'notes', NoteViewSet) + self.urls = self.router.urls + + def test_router_has_custom_name(self): + expected = 'nameable-root' + self.assertEqual(expected, self.urls[0].name) + diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 93ea9816c..9d89d1cbc 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -3,7 +3,7 @@ Provides various throttling policies. """ from __future__ import unicode_literals from django.core.cache import cache -from rest_framework import exceptions +from django.core.exceptions import ImproperlyConfigured from rest_framework.settings import api_settings import time @@ -65,13 +65,13 @@ class SimpleRateThrottle(BaseThrottle): if not getattr(self, 'scope', None): msg = ("You must set either `.scope` or `.rate` for '%s' throttle" % self.__class__.__name__) - raise exceptions.ConfigurationError(msg) + raise ImproperlyConfigured(msg) try: return self.settings.DEFAULT_THROTTLE_RATES[self.scope] except KeyError: msg = "No default throttle rate set for '%s' scope" % self.scope - raise exceptions.ConfigurationError(msg) + raise ImproperlyConfigured(msg) def parse_rate(self, rate): """ diff --git a/rest_framework/views.py b/rest_framework/views.py index e1b6705b6..c28d2835f 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -304,10 +304,10 @@ class APIView(View): `.dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. """ - request = self.initialize_request(request, *args, **kwargs) - self.request = request self.args = args self.kwargs = kwargs + request = self.initialize_request(request, *args, **kwargs) + self.request = request self.headers = self.default_response_headers # deprecate? try: @@ -341,8 +341,15 @@ class APIView(View): Return a dictionary of metadata about the view. Used to return responses for OPTIONS requests. """ + + # This is used by ViewSets to disambiguate instance vs list views + view_name_suffix = getattr(self, 'suffix', None) + + # By default we can't provide any form-like information, however the + # generic views override this implementation and add additional + # information for POST and PUT methods, based on the serializer. ret = SortedDict() - ret['name'] = get_view_name(self.__class__) + ret['name'] = get_view_name(self.__class__, view_name_suffix) ret['description'] = get_view_description(self.__class__) ret['renders'] = [renderer.media_type for renderer in self.renderer_classes] ret['parses'] = [parser.media_type for parser in self.parser_classes]