mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-02-02 20:54:42 +03:00
Merge pull request #1469 from entrouvert/master
authentication: allow all transport modes of access token in OAuth2Authentication
This commit is contained in:
commit
17f0871736
|
@ -6,6 +6,7 @@ import base64
|
||||||
|
|
||||||
from django.contrib.auth import authenticate
|
from django.contrib.auth import authenticate
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
from django.conf import settings
|
||||||
from rest_framework import exceptions, HTTP_HEADER_ENCODING
|
from rest_framework import exceptions, HTTP_HEADER_ENCODING
|
||||||
from rest_framework.compat import CsrfViewMiddleware
|
from rest_framework.compat import CsrfViewMiddleware
|
||||||
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
|
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
|
||||||
|
@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication):
|
||||||
OAuth 2 authentication backend using `django-oauth2-provider`
|
OAuth 2 authentication backend using `django-oauth2-provider`
|
||||||
"""
|
"""
|
||||||
www_authenticate_realm = 'api'
|
www_authenticate_realm = 'api'
|
||||||
|
allow_query_params_token = settings.DEBUG
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(OAuth2Authentication, self).__init__(*args, **kwargs)
|
super(OAuth2Authentication, self).__init__(*args, **kwargs)
|
||||||
|
@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication):
|
||||||
|
|
||||||
auth = get_authorization_header(request).split()
|
auth = get_authorization_header(request).split()
|
||||||
|
|
||||||
if not auth or auth[0].lower() != b'bearer':
|
if auth and auth[0].lower() == b'bearer':
|
||||||
|
access_token = auth[1]
|
||||||
|
elif 'access_token' in request.POST:
|
||||||
|
access_token = request.POST['access_token']
|
||||||
|
elif 'access_token' in request.GET and self.allow_query_params_token:
|
||||||
|
access_token = request.GET['access_token']
|
||||||
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if len(auth) == 1:
|
if len(auth) == 1:
|
||||||
|
@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication):
|
||||||
msg = 'Invalid bearer header. Token string should not contain spaces.'
|
msg = 'Invalid bearer header. Token string should not contain spaces.'
|
||||||
raise exceptions.AuthenticationFailed(msg)
|
raise exceptions.AuthenticationFailed(msg)
|
||||||
|
|
||||||
return self.authenticate_credentials(request, auth[1])
|
return self.authenticate_credentials(request, access_token)
|
||||||
|
|
||||||
def authenticate_credentials(self, request, access_token):
|
def authenticate_credentials(self, request, access_token):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -3,6 +3,7 @@ from django.contrib.auth.models import User
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.utils import unittest
|
from django.utils import unittest
|
||||||
|
from django.utils.http import urlencode
|
||||||
from rest_framework import HTTP_HEADER_ENCODING
|
from rest_framework import HTTP_HEADER_ENCODING
|
||||||
from rest_framework import exceptions
|
from rest_framework import exceptions
|
||||||
from rest_framework import permissions
|
from rest_framework import permissions
|
||||||
|
@ -53,10 +54,14 @@ urlpatterns = patterns('',
|
||||||
permission_classes=[permissions.TokenHasReadWriteScope]))
|
permission_classes=[permissions.TokenHasReadWriteScope]))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class OAuth2AuthenticationDebug(OAuth2Authentication):
|
||||||
|
allow_query_params_token = True
|
||||||
|
|
||||||
if oauth2_provider is not None:
|
if oauth2_provider is not None:
|
||||||
urlpatterns += patterns('',
|
urlpatterns += patterns('',
|
||||||
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
|
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
|
||||||
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
|
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
|
||||||
|
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
|
||||||
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
|
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
|
||||||
permission_classes=[permissions.TokenHasReadWriteScope])),
|
permission_classes=[permissions.TokenHasReadWriteScope])),
|
||||||
)
|
)
|
||||||
|
@ -545,6 +550,27 @@ class OAuth2Tests(TestCase):
|
||||||
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
def test_post_form_passing_auth_url_transport(self):
|
||||||
|
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
|
||||||
|
response = self.csrf_client.post('/oauth2-test/',
|
||||||
|
data={'access_token': self.access_token.token})
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
def test_get_form_passing_auth_url_transport(self):
|
||||||
|
"""Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
|
||||||
|
query = urlencode({'access_token': self.access_token.token})
|
||||||
|
response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
def test_get_form_failing_auth_url_transport(self):
|
||||||
|
"""Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
|
||||||
|
query = urlencode({'access_token': self.access_token.token})
|
||||||
|
response = self.csrf_client.get('/oauth2-test/?%s' % query)
|
||||||
|
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
def test_post_form_passing_auth(self):
|
def test_post_form_passing_auth(self):
|
||||||
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
|
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user