mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-01-26 09:14:34 +03:00
Merge pull request #3785 from sheppard/authtoken-import
don't import authtoken model until needed
This commit is contained in:
commit
37f7b76f72
|
@ -10,7 +10,6 @@ from django.middleware.csrf import CsrfViewMiddleware
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
from rest_framework import HTTP_HEADER_ENCODING, exceptions
|
from rest_framework import HTTP_HEADER_ENCODING, exceptions
|
||||||
from rest_framework.authtoken.models import Token
|
|
||||||
|
|
||||||
|
|
||||||
def get_authorization_header(request):
|
def get_authorization_header(request):
|
||||||
|
@ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication):
|
||||||
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
|
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model = Token
|
model = None
|
||||||
|
|
||||||
|
def get_model(self):
|
||||||
|
if self.model is not None:
|
||||||
|
return self.model
|
||||||
|
from rest_framework.authtoken.models import Token
|
||||||
|
return Token
|
||||||
|
|
||||||
"""
|
"""
|
||||||
A custom token model may be used, but must have the following properties.
|
A custom token model may be used, but must have the following properties.
|
||||||
|
|
||||||
|
@ -179,9 +185,10 @@ class TokenAuthentication(BaseAuthentication):
|
||||||
return self.authenticate_credentials(token)
|
return self.authenticate_credentials(token)
|
||||||
|
|
||||||
def authenticate_credentials(self, key):
|
def authenticate_credentials(self, key):
|
||||||
|
model = self.get_model()
|
||||||
try:
|
try:
|
||||||
token = self.model.objects.select_related('user').get(key=key)
|
token = model.objects.select_related('user').get(key=key)
|
||||||
except self.model.DoesNotExist:
|
except model.DoesNotExist:
|
||||||
raise exceptions.AuthenticationFailed(_('Invalid token.'))
|
raise exceptions.AuthenticationFailed(_('Invalid token.'))
|
||||||
|
|
||||||
if not token.user.is_active:
|
if not token.user.is_active:
|
||||||
|
|
|
@ -21,14 +21,6 @@ class Token(models.Model):
|
||||||
user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
|
user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
|
||||||
created = models.DateTimeField(auto_now_add=True)
|
created = models.DateTimeField(auto_now_add=True)
|
||||||
|
|
||||||
class Meta:
|
|
||||||
# Work around for a bug in Django:
|
|
||||||
# https://code.djangoproject.com/ticket/19422
|
|
||||||
#
|
|
||||||
# Also see corresponding ticket:
|
|
||||||
# https://github.com/tomchristie/django-rest-framework/issues/705
|
|
||||||
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
|
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
if not self.key:
|
if not self.key:
|
||||||
self.key = self.generate_key()
|
self.key = self.generate_key()
|
||||||
|
|
|
@ -6,6 +6,7 @@ import base64
|
||||||
|
|
||||||
from django.conf.urls import include, url
|
from django.conf.urls import include, url
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
|
from django.db import models
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
@ -25,6 +26,15 @@ from rest_framework.views import APIView
|
||||||
factory = APIRequestFactory()
|
factory = APIRequestFactory()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomToken(models.Model):
|
||||||
|
key = models.CharField(max_length=40, primary_key=True)
|
||||||
|
user = models.OneToOneField(User)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomTokenAuthentication(TokenAuthentication):
|
||||||
|
model = CustomToken
|
||||||
|
|
||||||
|
|
||||||
class MockView(APIView):
|
class MockView(APIView):
|
||||||
permission_classes = (permissions.IsAuthenticated,)
|
permission_classes = (permissions.IsAuthenticated,)
|
||||||
|
|
||||||
|
@ -42,6 +52,7 @@ urlpatterns = [
|
||||||
url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
|
url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
|
||||||
url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
|
url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
|
||||||
url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
|
url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
|
||||||
|
url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])),
|
||||||
url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
|
url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
|
||||||
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
|
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
|
||||||
]
|
]
|
||||||
|
@ -142,9 +153,11 @@ class SessionAuthTests(TestCase):
|
||||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
|
|
||||||
class TokenAuthTests(TestCase):
|
class BaseTokenAuthTests(object):
|
||||||
"""Token authentication"""
|
"""Token authentication"""
|
||||||
urls = 'tests.test_authentication'
|
urls = 'tests.test_authentication'
|
||||||
|
model = None
|
||||||
|
path = None
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.csrf_client = APIClient(enforce_csrf_checks=True)
|
self.csrf_client = APIClient(enforce_csrf_checks=True)
|
||||||
|
@ -154,24 +167,30 @@ class TokenAuthTests(TestCase):
|
||||||
self.user = User.objects.create_user(self.username, self.email, self.password)
|
self.user = User.objects.create_user(self.username, self.email, self.password)
|
||||||
|
|
||||||
self.key = 'abcd1234'
|
self.key = 'abcd1234'
|
||||||
self.token = Token.objects.create(key=self.key, user=self.user)
|
self.token = self.model.objects.create(key=self.key, user=self.user)
|
||||||
|
|
||||||
def test_post_form_passing_token_auth(self):
|
def test_post_form_passing_token_auth(self):
|
||||||
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
|
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
|
||||||
auth = 'Token ' + self.key
|
auth = 'Token ' + self.key
|
||||||
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
|
||||||
|
def test_fail_post_form_passing_nonexistent_token_auth(self):
|
||||||
|
# use a nonexistent token key
|
||||||
|
auth = 'Token wxyz6789'
|
||||||
|
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
def test_fail_post_form_passing_invalid_token_auth(self):
|
def test_fail_post_form_passing_invalid_token_auth(self):
|
||||||
# add an 'invalid' unicode character
|
# add an 'invalid' unicode character
|
||||||
auth = 'Token ' + self.key + "¸"
|
auth = 'Token ' + self.key + "¸"
|
||||||
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
def test_post_json_passing_token_auth(self):
|
def test_post_json_passing_token_auth(self):
|
||||||
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
|
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
|
||||||
auth = "Token " + self.key
|
auth = "Token " + self.key
|
||||||
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
|
||||||
def test_post_json_makes_one_db_query(self):
|
def test_post_json_makes_one_db_query(self):
|
||||||
|
@ -179,29 +198,34 @@ class TokenAuthTests(TestCase):
|
||||||
auth = "Token " + self.key
|
auth = "Token " + self.key
|
||||||
|
|
||||||
def func_to_test():
|
def func_to_test():
|
||||||
return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
|
return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
|
||||||
|
|
||||||
self.assertNumQueries(1, func_to_test)
|
self.assertNumQueries(1, func_to_test)
|
||||||
|
|
||||||
def test_post_form_failing_token_auth(self):
|
def test_post_form_failing_token_auth(self):
|
||||||
"""Ensure POSTing form over token auth without correct credentials fails"""
|
"""Ensure POSTing form over token auth without correct credentials fails"""
|
||||||
response = self.csrf_client.post('/token/', {'example': 'example'})
|
response = self.csrf_client.post(self.path, {'example': 'example'})
|
||||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
def test_post_json_failing_token_auth(self):
|
def test_post_json_failing_token_auth(self):
|
||||||
"""Ensure POSTing json over token auth without correct credentials fails"""
|
"""Ensure POSTing json over token auth without correct credentials fails"""
|
||||||
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
|
response = self.csrf_client.post(self.path, {'example': 'example'}, format='json')
|
||||||
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenAuthTests(BaseTokenAuthTests, TestCase):
|
||||||
|
model = Token
|
||||||
|
path = '/token/'
|
||||||
|
|
||||||
def test_token_has_auto_assigned_key_if_none_provided(self):
|
def test_token_has_auto_assigned_key_if_none_provided(self):
|
||||||
"""Ensure creating a token with no key will auto-assign a key"""
|
"""Ensure creating a token with no key will auto-assign a key"""
|
||||||
self.token.delete()
|
self.token.delete()
|
||||||
token = Token.objects.create(user=self.user)
|
token = self.model.objects.create(user=self.user)
|
||||||
self.assertTrue(bool(token.key))
|
self.assertTrue(bool(token.key))
|
||||||
|
|
||||||
def test_generate_key_returns_string(self):
|
def test_generate_key_returns_string(self):
|
||||||
"""Ensure generate_key returns a string"""
|
"""Ensure generate_key returns a string"""
|
||||||
token = Token()
|
token = self.model()
|
||||||
key = token.generate_key()
|
key = token.generate_key()
|
||||||
self.assertTrue(isinstance(key, six.string_types))
|
self.assertTrue(isinstance(key, six.string_types))
|
||||||
|
|
||||||
|
@ -236,6 +260,11 @@ class TokenAuthTests(TestCase):
|
||||||
self.assertEqual(response.data['token'], self.key)
|
self.assertEqual(response.data['token'], self.key)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
|
||||||
|
model = CustomToken
|
||||||
|
path = '/customtoken/'
|
||||||
|
|
||||||
|
|
||||||
class IncorrectCredentialsTests(TestCase):
|
class IncorrectCredentialsTests(TestCase):
|
||||||
def test_incorrect_credentials(self):
|
def test_incorrect_credentials(self):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user