Merge pull request #1469 from entrouvert/master

authentication: allow all transport modes of access token in OAuth2Authentication
This commit is contained in:
Tom Christie 2014-03-21 12:23:49 +00:00
commit 17f0871736
2 changed files with 36 additions and 2 deletions

View File

@ -6,6 +6,7 @@ import base64
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication):
OAuth 2 authentication backend using `django-oauth2-provider` OAuth 2 authentication backend using `django-oauth2-provider`
""" """
www_authenticate_realm = 'api' www_authenticate_realm = 'api'
allow_query_params_token = settings.DEBUG
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(OAuth2Authentication, self).__init__(*args, **kwargs) super(OAuth2Authentication, self).__init__(*args, **kwargs)
@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication):
auth = get_authorization_header(request).split() auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != b'bearer': if auth and auth[0].lower() == b'bearer':
access_token = auth[1]
elif 'access_token' in request.POST:
access_token = request.POST['access_token']
elif 'access_token' in request.GET and self.allow_query_params_token:
access_token = request.GET['access_token']
else:
return None return None
if len(auth) == 1: if len(auth) == 1:
@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication):
msg = 'Invalid bearer header. Token string should not contain spaces.' msg = 'Invalid bearer header. Token string should not contain spaces.'
raise exceptions.AuthenticationFailed(msg) raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(request, auth[1]) return self.authenticate_credentials(request, access_token)
def authenticate_credentials(self, request, access_token): def authenticate_credentials(self, request, access_token):
""" """

View File

@ -3,6 +3,7 @@ from django.contrib.auth.models import User
from django.http import HttpResponse from django.http import HttpResponse
from django.test import TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from django.utils.http import urlencode
from rest_framework import HTTP_HEADER_ENCODING from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework import permissions from rest_framework import permissions
@ -53,10 +54,14 @@ urlpatterns = patterns('',
permission_classes=[permissions.TokenHasReadWriteScope])) permission_classes=[permissions.TokenHasReadWriteScope]))
) )
class OAuth2AuthenticationDebug(OAuth2Authentication):
allow_query_params_token = True
if oauth2_provider is not None: if oauth2_provider is not None:
urlpatterns += patterns('', urlpatterns += patterns('',
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
permission_classes=[permissions.TokenHasReadWriteScope])), permission_classes=[permissions.TokenHasReadWriteScope])),
) )
@ -545,6 +550,27 @@ class OAuth2Tests(TestCase):
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
response = self.csrf_client.post('/oauth2-test/',
data={'access_token': self.access_token.token})
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_passing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
query = urlencode({'access_token': self.access_token.token})
response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_get_form_failing_auth_url_transport(self):
"""Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
query = urlencode({'access_token': self.access_token.token})
response = self.csrf_client.get('/oauth2-test/?%s' % query)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self): def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""