Merge pull request #3785 from sheppard/authtoken-import

don't import authtoken model until needed
This commit is contained in:
Tom Christie 2016-01-05 17:28:48 +00:00
commit 37f7b76f72
3 changed files with 50 additions and 22 deletions

View File

@ -10,7 +10,6 @@ from django.middleware.csrf import CsrfViewMiddleware
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING, exceptions from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.authtoken.models import Token
def get_authorization_header(request): def get_authorization_header(request):
@ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication):
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
""" """
model = Token model = None
def get_model(self):
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token
return Token
""" """
A custom token model may be used, but must have the following properties. A custom token model may be used, but must have the following properties.
@ -179,9 +185,10 @@ class TokenAuthentication(BaseAuthentication):
return self.authenticate_credentials(token) return self.authenticate_credentials(token)
def authenticate_credentials(self, key): def authenticate_credentials(self, key):
model = self.get_model()
try: try:
token = self.model.objects.select_related('user').get(key=key) token = model.objects.select_related('user').get(key=key)
except self.model.DoesNotExist: except model.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.')) raise exceptions.AuthenticationFailed(_('Invalid token.'))
if not token.user.is_active: if not token.user.is_active:

View File

@ -21,14 +21,6 @@ class Token(models.Model):
user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token') user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
created = models.DateTimeField(auto_now_add=True) created = models.DateTimeField(auto_now_add=True)
class Meta:
# Work around for a bug in Django:
# https://code.djangoproject.com/ticket/19422
#
# Also see corresponding ticket:
# https://github.com/tomchristie/django-rest-framework/issues/705
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if not self.key: if not self.key:
self.key = self.generate_key() self.key = self.generate_key()

View File

@ -6,6 +6,7 @@ import base64
from django.conf.urls import include, url from django.conf.urls import include, url
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.db import models
from django.http import HttpResponse from django.http import HttpResponse
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
@ -25,6 +26,15 @@ from rest_framework.views import APIView
factory = APIRequestFactory() factory = APIRequestFactory()
class CustomToken(models.Model):
key = models.CharField(max_length=40, primary_key=True)
user = models.OneToOneField(User)
class CustomTokenAuthentication(TokenAuthentication):
model = CustomToken
class MockView(APIView): class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,) permission_classes = (permissions.IsAuthenticated,)
@ -42,6 +52,7 @@ urlpatterns = [
url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])),
url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
] ]
@ -142,9 +153,11 @@ class SessionAuthTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
class TokenAuthTests(TestCase): class BaseTokenAuthTests(object):
"""Token authentication""" """Token authentication"""
urls = 'tests.test_authentication' urls = 'tests.test_authentication'
model = None
path = None
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
@ -154,24 +167,30 @@ class TokenAuthTests(TestCase):
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(self.username, self.email, self.password)
self.key = 'abcd1234' self.key = 'abcd1234'
self.token = Token.objects.create(key=self.key, user=self.user) self.token = self.model.objects.create(key=self.key, user=self.user)
def test_post_form_passing_token_auth(self): def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = 'Token ' + self.key auth = 'Token ' + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_fail_post_form_passing_nonexistent_token_auth(self):
# use a nonexistent token key
auth = 'Token wxyz6789'
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_fail_post_form_passing_invalid_token_auth(self): def test_fail_post_form_passing_invalid_token_auth(self):
# add an 'invalid' unicode character # add an 'invalid' unicode character
auth = 'Token ' + self.key + "¸" auth = 'Token ' + self.key + "¸"
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_passing_token_auth(self): def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_makes_one_db_query(self): def test_post_json_makes_one_db_query(self):
@ -179,29 +198,34 @@ class TokenAuthTests(TestCase):
auth = "Token " + self.key auth = "Token " + self.key
def func_to_test(): def func_to_test():
return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertNumQueries(1, func_to_test) self.assertNumQueries(1, func_to_test)
def test_post_form_failing_token_auth(self): def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails""" """Ensure POSTing form over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'}) response = self.csrf_client.post(self.path, {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails""" """Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') response = self.csrf_client.post(self.path, {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class TokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
path = '/token/'
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key""" """Ensure creating a token with no key will auto-assign a key"""
self.token.delete() self.token.delete()
token = Token.objects.create(user=self.user) token = self.model.objects.create(user=self.user)
self.assertTrue(bool(token.key)) self.assertTrue(bool(token.key))
def test_generate_key_returns_string(self): def test_generate_key_returns_string(self):
"""Ensure generate_key returns a string""" """Ensure generate_key returns a string"""
token = Token() token = self.model()
key = token.generate_key() key = token.generate_key()
self.assertTrue(isinstance(key, six.string_types)) self.assertTrue(isinstance(key, six.string_types))
@ -236,6 +260,11 @@ class TokenAuthTests(TestCase):
self.assertEqual(response.data['token'], self.key) self.assertEqual(response.data['token'], self.key)
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
model = CustomToken
path = '/customtoken/'
class IncorrectCredentialsTests(TestCase): class IncorrectCredentialsTests(TestCase):
def test_incorrect_credentials(self): def test_incorrect_credentials(self):
""" """