mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-04 04:20:12 +03:00
Merge 43a4a66644
into d46d153a99
This commit is contained in:
commit
33e33689e0
|
@ -10,7 +10,11 @@ from django.conf import settings
|
||||||
from rest_framework import exceptions, HTTP_HEADER_ENCODING
|
from rest_framework import exceptions, HTTP_HEADER_ENCODING
|
||||||
from rest_framework.compat import CsrfViewMiddleware
|
from rest_framework.compat import CsrfViewMiddleware
|
||||||
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
|
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
|
||||||
from rest_framework.compat import oauth2_provider, provider_now, check_nonce
|
from rest_framework.compat import (oauth2_provider,
|
||||||
|
provider_now,
|
||||||
|
check_nonce,
|
||||||
|
oauth2_constants,
|
||||||
|
)
|
||||||
from rest_framework.authtoken.models import Token
|
from rest_framework.authtoken.models import Token
|
||||||
|
|
||||||
|
|
||||||
|
@ -293,6 +297,7 @@ class OAuth2Authentication(BaseAuthentication):
|
||||||
"""
|
"""
|
||||||
www_authenticate_realm = 'api'
|
www_authenticate_realm = 'api'
|
||||||
allow_query_params_token = settings.DEBUG
|
allow_query_params_token = settings.DEBUG
|
||||||
|
token_type = getattr(oauth2_constants, 'TOKEN_TYPE', b'Bearer')
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(OAuth2Authentication, self).__init__(*args, **kwargs)
|
super(OAuth2Authentication, self).__init__(*args, **kwargs)
|
||||||
|
@ -310,7 +315,7 @@ class OAuth2Authentication(BaseAuthentication):
|
||||||
|
|
||||||
auth = get_authorization_header(request).split()
|
auth = get_authorization_header(request).split()
|
||||||
|
|
||||||
if auth and auth[0].lower() == b'bearer':
|
if auth and auth[0].lower() == self.token_type.lower():
|
||||||
access_token = auth[1]
|
access_token = auth[1]
|
||||||
elif 'access_token' in request.POST:
|
elif 'access_token' in request.POST:
|
||||||
access_token = request.POST['access_token']
|
access_token = request.POST['access_token']
|
||||||
|
@ -355,4 +360,4 @@ class OAuth2Authentication(BaseAuthentication):
|
||||||
|
|
||||||
Check details on the `OAuth2Authentication.authenticate` method
|
Check details on the `OAuth2Authentication.authenticate` method
|
||||||
"""
|
"""
|
||||||
return 'Bearer realm="%s"' % self.www_authenticate_realm
|
return '%s realm="%s"' % (self.token_type, self.www_authenticate_realm)
|
||||||
|
|
|
@ -57,10 +57,16 @@ urlpatterns = patterns('',
|
||||||
class OAuth2AuthenticationDebug(OAuth2Authentication):
|
class OAuth2AuthenticationDebug(OAuth2Authentication):
|
||||||
allow_query_params_token = True
|
allow_query_params_token = True
|
||||||
|
|
||||||
|
|
||||||
|
class OAuth2AuthenticationCustomTokenType(OAuth2Authentication):
|
||||||
|
token_type = 'Custom'
|
||||||
|
|
||||||
|
|
||||||
if oauth2_provider is not None:
|
if oauth2_provider is not None:
|
||||||
urlpatterns += patterns('',
|
urlpatterns += patterns('',
|
||||||
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
|
url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')),
|
||||||
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
|
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
|
||||||
|
url(r'^oauth2-test-token-type/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationCustomTokenType])),
|
||||||
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
|
url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])),
|
||||||
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
|
url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication],
|
||||||
permission_classes=[permissions.TokenHasReadWriteScope])),
|
permission_classes=[permissions.TokenHasReadWriteScope])),
|
||||||
|
@ -556,6 +562,14 @@ class OAuth2Tests(TestCase):
|
||||||
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
def test_get_form_passing_authi_with_custom_token_type(self):
|
||||||
|
"""Ensure GETing form over OAuth with correct client credentials
|
||||||
|
and custom TOKEN_TYPE succeed"""
|
||||||
|
auth = "Custom {0}".format(self.access_token.token)
|
||||||
|
response = self.csrf_client.get('/oauth2-test-token-type/', HTTP_AUTHORIZATION=auth)
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
def test_post_form_passing_auth_url_transport(self):
|
def test_post_form_passing_auth_url_transport(self):
|
||||||
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
|
"""Ensure GETing form over OAuth with correct client credentials in form data succeed"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user