diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index da9ca510e..d8a962dc9 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -10,7 +10,11 @@ from django.conf import settings from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store -from rest_framework.compat import oauth2_provider, provider_now, check_nonce +from rest_framework.compat import (oauth2_provider, + provider_now, + check_nonce, + oauth2_constants, + ) from rest_framework.authtoken.models import Token @@ -293,6 +297,7 @@ class OAuth2Authentication(BaseAuthentication): """ www_authenticate_realm = 'api' allow_query_params_token = settings.DEBUG + token_type = getattr(oauth2_constants, 'TOKEN_TYPE', b'Bearer') def __init__(self, *args, **kwargs): super(OAuth2Authentication, self).__init__(*args, **kwargs) @@ -310,7 +315,7 @@ class OAuth2Authentication(BaseAuthentication): auth = get_authorization_header(request).split() - if auth and auth[0].lower() == b'bearer': + if auth and auth[0].lower() == self.token_type.lower(): access_token = auth[1] elif 'access_token' in request.POST: access_token = request.POST['access_token'] @@ -355,4 +360,4 @@ class OAuth2Authentication(BaseAuthentication): Check details on the `OAuth2Authentication.authenticate` method """ - return 'Bearer realm="%s"' % self.www_authenticate_realm + return '%s realm="%s"' % (self.token_type, self.www_authenticate_realm) diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index a1c43d9ce..e0711d844 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -57,10 +57,16 @@ urlpatterns = patterns('', class OAuth2AuthenticationDebug(OAuth2Authentication): allow_query_params_token = True + +class OAuth2AuthenticationCustomTokenType(OAuth2Authentication): + token_type = 'Custom' + + if oauth2_provider is not None: urlpatterns += patterns('', url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])), + url(r'^oauth2-test-token-type/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationCustomTokenType])), url(r'^oauth2-test-debug/$', MockView.as_view(authentication_classes=[OAuth2AuthenticationDebug])), url(r'^oauth2-with-scope-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication], permission_classes=[permissions.TokenHasReadWriteScope])), @@ -556,6 +562,14 @@ class OAuth2Tests(TestCase): response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 200) + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') + def test_get_form_passing_authi_with_custom_token_type(self): + """Ensure GETing form over OAuth with correct client credentials + and custom TOKEN_TYPE succeed""" + auth = "Custom {0}".format(self.access_token.token) + response = self.csrf_client.get('/oauth2-test-token-type/', HTTP_AUTHORIZATION=auth) + self.assertEqual(response.status_code, 200) + @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') def test_post_form_passing_auth_url_transport(self): """Ensure GETing form over OAuth with correct client credentials in form data succeed"""