mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-23 15:54:16 +03:00
Merge pull request #1469 from entrouvert/master
authentication: allow all transport modes of access token in OAuth2Authentication
This commit is contained in:
commit
17f0871736
|
@ -6,6 +6,7 @@ import base64
|
|||
|
||||
from django.contrib.auth import authenticate
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.conf import settings
|
||||
from rest_framework import exceptions, HTTP_HEADER_ENCODING
|
||||
from rest_framework.compat import CsrfViewMiddleware
|
||||
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
|
||||
|
@ -291,6 +292,7 @@ class OAuth2Authentication(BaseAuthentication):
|
|||
OAuth 2 authentication backend using `django-oauth2-provider`
|
||||
"""
|
||||
www_authenticate_realm = 'api'
|
||||
allow_query_params_token = settings.DEBUG
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(OAuth2Authentication, self).__init__(*args, **kwargs)
|
||||
|
@ -308,7 +310,13 @@ class OAuth2Authentication(BaseAuthentication):
|
|||
|
||||
auth = get_authorization_header(request).split()
|
||||
|
||||
if not auth or auth[0].lower() != b'bearer':
|
||||
if auth and auth[0].lower() == b'bearer':
|
||||
access_token = auth[1]
|
||||
elif 'access_token' in request.POST:
|
||||
access_token = request.POST['access_token']
|
||||
elif 'access_token' in request.GET and self.allow_query_params_token:
|
||||
access_token = request.GET['access_token']
|
||||
else:
|
||||
return None
|
||||
|
||||
if len(auth) == 1:
|
||||
|
@ -318,7 +326,7 @@ class OAuth2Authentication(BaseAuthentication):
|
|||
msg = 'Invalid bearer header. Token string should not contain spaces.'
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
return self.authenticate_credentials(request, auth[1])
|
||||
return self.authenticate_credentials(request, access_token)
|
||||
|
||||
def authenticate_credentials(self, request, access_token):
|
||||
"""
|
||||
|
|
|
@ -3,6 +3,7 @@ from django.contrib.auth.models import User
|
|||
from django.http import HttpResponse
|
||||
from django.test import TestCase
|
||||
from django.utils import unittest
|
||||
from django.utils.http import urlencode
|
||||
from rest_framework import HTTP_HEADER_ENCODING
|
||||
from rest_framework import exceptions
|
||||
from rest_framework import permissions
|
||||
|
@ -53,10 +54,14 @@ urlpatterns = patterns('',
|
|||
permission_classes=[permissions.TokenHasReadWriteScope]))
|
||||
)
|
||||
|
||||
class OAuth2AuthenticationDebug(OAuth2Authentication):
|
||||
allow_query_params_token = True
|
||||
|
||||
if oauth2_provider is not None:
|
||||
urlpatterns += patterns('',
|
||||
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
|
||||
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
|
||||
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
|
||||
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
|
||||
permission_classes=[permissions.TokenHasReadWriteScope])),
|
||||
)
|
||||
|
@ -545,6 +550,27 @@ class OAuth2Tests(TestCase):
|
|||
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_passing_auth_url_transport(self):
|
||||
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
|
||||
response = self.csrf_client.post('/oauth2-test/',
|
||||
data={'access_token': self.access_token.token})
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_passing_auth_url_transport(self):
|
||||
"""Ensure GETing form over OAuth with correct client credentials in query succeed when DEBUG is True"""
|
||||
query = urlencode({'access_token': self.access_token.token})
|
||||
response = self.csrf_client.get('/oauth2-test-debug/?%s' % query)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_get_form_failing_auth_url_transport(self):
|
||||
"""Ensure GETing form over OAuth with correct client credentials in query fails when DEBUG is False"""
|
||||
query = urlencode({'access_token': self.access_token.token})
|
||||
response = self.csrf_client.get('/oauth2-test/?%s' % query)
|
||||
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||
|
||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||
def test_post_form_passing_auth(self):
|
||||
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user