diff --git a/docs/api-guide/authentication.md b/docs/api-guide/authentication.md index 06d043b6c..5568ed0c9 100644 --- a/docs/api-guide/authentication.md +++ b/docs/api-guide/authentication.md @@ -143,6 +143,8 @@ For clients to authenticate, the token key should be included in the `Authorizat Authorization: Token 9944b09199c62bcf9418ad846dd0e4bbdfc6ee4b +**Note:** If you want to use a different keyword in the header, such as `Bearer`, simply subclass `TokenAuthentication` and set the `keyword` class variable. + If successfully authenticated, `TokenAuthentication` provides the following credentials. * `request.user` will be a Django `User` instance. diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 120be6165..cb9608a3c 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -150,6 +150,7 @@ class TokenAuthentication(BaseAuthentication): Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a """ + keyword = 'Token' model = None def get_model(self): @@ -168,7 +169,7 @@ class TokenAuthentication(BaseAuthentication): def authenticate(self, request): auth = get_authorization_header(request).split() - if not auth or auth[0].lower() != b'token': + if not auth or auth[0].lower() != self.keyword.lower().encode(): return None if len(auth) == 1: @@ -199,4 +200,4 @@ class TokenAuthentication(BaseAuthentication): return (token.user, token) def authenticate_header(self, request): - return 'Token' + return self.keyword diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 9aff7280b..9784087d8 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -35,6 +35,10 @@ class CustomTokenAuthentication(TokenAuthentication): model = CustomToken +class CustomKeywordTokenAuthentication(TokenAuthentication): + keyword = 'Bearer' + + class MockView(APIView): permission_classes = (permissions.IsAuthenticated,) @@ -53,6 +57,7 @@ urlpatterns = [ url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])), + url(r'^customkeywordtoken/$', MockView.as_view(authentication_classes=[CustomKeywordTokenAuthentication])), url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), ] @@ -166,6 +171,7 @@ class BaseTokenAuthTests(object): urls = 'tests.test_authentication' model = None path = None + header_prefix = 'Token ' def setUp(self): self.csrf_client = APIClient(enforce_csrf_checks=True) @@ -179,31 +185,31 @@ class BaseTokenAuthTests(object): def test_post_form_passing_token_auth(self): """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" - auth = 'Token ' + self.key + auth = self.header_prefix + self.key response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_fail_post_form_passing_nonexistent_token_auth(self): # use a nonexistent token key - auth = 'Token wxyz6789' + auth = self.header_prefix + 'wxyz6789' response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_fail_post_form_passing_invalid_token_auth(self): # add an 'invalid' unicode character - auth = 'Token ' + self.key + "¸" + auth = self.header_prefix + self.key + "¸" response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" - auth = "Token " + self.key + auth = self.header_prefix + self.key response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_json_makes_one_db_query(self): """Ensure that authenticating a user using a token performs only one DB query""" - auth = "Token " + self.key + auth = self.header_prefix + self.key def func_to_test(): return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) @@ -273,6 +279,12 @@ class CustomTokenAuthTests(BaseTokenAuthTests, TestCase): path = '/customtoken/' +class CustomKeywordTokenAuthTests(BaseTokenAuthTests, TestCase): + model = Token + path = '/customkeywordtoken/' + header_prefix = 'Bearer ' + + class IncorrectCredentialsTests(TestCase): def test_incorrect_credentials(self): """