This commit is contained in:
Tom Christie 2013-06-10 09:06:15 +01:00
commit 351e172d45
8 changed files with 66 additions and 9 deletions

View File

@ -39,7 +39,7 @@ Declaring a serializer looks very similar to declaring a form:
an existing model instance, or create a new model instance. an existing model instance, or create a new model instance.
""" """
if instance is not None: if instance is not None:
instance.title = attrs.get('title', instance.title) instance.email = attrs.get('email', instance.email)
instance.content = attrs.get('content', instance.content) instance.content = attrs.get('content', instance.content)
instance.created = attrs.get('created', instance.created) instance.created = attrs.get('created', instance.created)
return instance return instance

View File

@ -209,8 +209,6 @@ To create a base viewset class that provides `create`, `list` and `retrieve` ope
mixins.ListMixin, mixins.ListMixin,
mixins.RetrieveMixin, mixins.RetrieveMixin,
viewsets.GenericViewSet): viewsets.GenericViewSet):
pass
""" """
A viewset that provides `retrieve`, `update`, and `list` actions. A viewset that provides `retrieve`, `update`, and `list` actions.

View File

@ -140,6 +140,7 @@ The following people have helped make REST framework great.
* Alex Burgel - [aburgel] * Alex Burgel - [aburgel]
* David Medina - [copitux] * David Medina - [copitux]
* Areski Belaid - [areski] * Areski Belaid - [areski]
* Ethan Freman - [mindlace]
Many thanks to everyone who's contributed to the project. Many thanks to everyone who's contributed to the project.
@ -316,3 +317,4 @@ You can also contact [@_tomchristie][twitter] directly on twitter.
[aburgel]: https://github.com/aburgel [aburgel]: https://github.com/aburgel
[copitux]: https://github.com/copitux [copitux]: https://github.com/copitux
[areski]: https://github.com/areski [areski]: https://github.com/areski
[mindlace]: https://github.com/mindlace

View File

@ -230,8 +230,9 @@ class OAuthAuthentication(BaseAuthentication):
try: try:
consumer_key = oauth_request.get_parameter('oauth_consumer_key') consumer_key = oauth_request.get_parameter('oauth_consumer_key')
consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key) consumer = oauth_provider_store.get_consumer(request, oauth_request, consumer_key)
except oauth_provider.store.InvalidConsumerError as err: except oauth_provider.store.InvalidConsumerError:
raise exceptions.AuthenticationFailed(err) msg = 'Invalid consumer token: %s' % oauth_request.get_parameter('oauth_consumer_key')
raise exceptions.AuthenticationFailed(msg)
if consumer.status != oauth_provider.consts.ACCEPTED: if consumer.status != oauth_provider.consts.ACCEPTED:
msg = 'Invalid consumer key status: %s' % consumer.get_status_display() msg = 'Invalid consumer key status: %s' % consumer.get_status_display()

View File

@ -215,6 +215,7 @@ class DefaultRouter(SimpleRouter):
""" """
include_root_view = True include_root_view = True
include_format_suffixes = True include_format_suffixes = True
root_view_name = 'api-root'
def get_api_root_view(self): def get_api_root_view(self):
""" """
@ -244,7 +245,7 @@ class DefaultRouter(SimpleRouter):
urls = [] urls = []
if self.include_root_view: if self.include_root_view:
root_url = url(r'^$', self.get_api_root_view(), name='api-root') root_url = url(r'^$', self.get_api_root_view(), name=self.root_view_name)
urls.append(root_url) urls.append(root_url)
default_urls = super(DefaultRouter, self).get_urls() default_urls = super(DefaultRouter, self).get_urls()

View File

@ -428,6 +428,47 @@ class OAuthTests(TestCase):
response = self.csrf_client.post('/oauth-with-scope/', params) response = self.csrf_client.post('/oauth-with-scope/', params)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
@unittest.skipUnless(oauth, 'oauth2 not installed')
def test_bad_consumer_key(self):
"""Ensure POSTing using HMAC_SHA1 signature method passes"""
params = {
'oauth_version': "1.0",
'oauth_nonce': oauth.generate_nonce(),
'oauth_timestamp': int(time.time()),
'oauth_token': self.token.key,
'oauth_consumer_key': 'badconsumerkey'
}
req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
signature_method = oauth.SignatureMethod_HMAC_SHA1()
req.sign_request(signature_method, self.consumer, self.token)
auth = req.to_header()["Authorization"]
response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth_provider, 'django-oauth-plus not installed')
@unittest.skipUnless(oauth, 'oauth2 not installed')
def test_bad_token_key(self):
"""Ensure POSTing using HMAC_SHA1 signature method passes"""
params = {
'oauth_version': "1.0",
'oauth_nonce': oauth.generate_nonce(),
'oauth_timestamp': int(time.time()),
'oauth_token': 'badtokenkey',
'oauth_consumer_key': self.consumer.key
}
req = oauth.Request(method="POST", url="http://testserver/oauth/", parameters=params)
signature_method = oauth.SignatureMethod_HMAC_SHA1()
req.sign_request(signature_method, self.consumer, self.token)
auth = req.to_header()["Authorization"]
response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
class OAuth2Tests(TestCase): class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication""" """OAuth 2.0 authentication"""

View File

@ -6,7 +6,7 @@ from rest_framework import serializers, viewsets
from rest_framework.compat import include, patterns, url from rest_framework.compat import include, patterns, url
from rest_framework.decorators import link, action from rest_framework.decorators import link, action
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import SimpleRouter from rest_framework.routers import SimpleRouter, DefaultRouter
factory = RequestFactory() factory = RequestFactory()
@ -148,3 +148,17 @@ class TestTrailingSlash(TestCase):
expected = ['^notes$', '^notes/(?P<pk>[^/]+)$'] expected = ['^notes$', '^notes/(?P<pk>[^/]+)$']
for idx in range(len(expected)): for idx in range(len(expected)):
self.assertEqual(expected[idx], self.urls[idx].regex.pattern) self.assertEqual(expected[idx], self.urls[idx].regex.pattern)
class TestNameableRoot(TestCase):
def setUp(self):
class NoteViewSet(viewsets.ModelViewSet):
model = RouterTestModel
self.router = DefaultRouter()
self.router.root_view_name = 'nameable-root'
self.router.register(r'notes', NoteViewSet)
self.urls = self.router.urls
def test_router_has_custom_name(self):
expected = 'nameable-root'
self.assertEqual(expected, self.urls[0].name)

View File

@ -304,10 +304,10 @@ class APIView(View):
`.dispatch()` is pretty much the same as Django's regular dispatch, `.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling. but with extra hooks for startup, finalize, and exception handling.
""" """
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.args = args self.args = args
self.kwargs = kwargs self.kwargs = kwargs
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.headers = self.default_response_headers # deprecate? self.headers = self.default_response_headers # deprecate?
try: try: