mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-10-31 07:57:55 +03:00 
			
		
		
		
	Merge pull request #3785 from sheppard/authtoken-import
don't import authtoken model until needed
This commit is contained in:
		
						commit
						37f7b76f72
					
				|  | @ -10,7 +10,6 @@ from django.middleware.csrf import CsrfViewMiddleware | ||||||
| from django.utils.translation import ugettext_lazy as _ | from django.utils.translation import ugettext_lazy as _ | ||||||
| 
 | 
 | ||||||
| from rest_framework import HTTP_HEADER_ENCODING, exceptions | from rest_framework import HTTP_HEADER_ENCODING, exceptions | ||||||
| from rest_framework.authtoken.models import Token |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def get_authorization_header(request): | def get_authorization_header(request): | ||||||
|  | @ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication): | ||||||
|         Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a |         Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     model = Token |     model = None | ||||||
|  | 
 | ||||||
|  |     def get_model(self): | ||||||
|  |         if self.model is not None: | ||||||
|  |             return self.model | ||||||
|  |         from rest_framework.authtoken.models import Token | ||||||
|  |         return Token | ||||||
|  | 
 | ||||||
|     """ |     """ | ||||||
|     A custom token model may be used, but must have the following properties. |     A custom token model may be used, but must have the following properties. | ||||||
| 
 | 
 | ||||||
|  | @ -179,9 +185,10 @@ class TokenAuthentication(BaseAuthentication): | ||||||
|         return self.authenticate_credentials(token) |         return self.authenticate_credentials(token) | ||||||
| 
 | 
 | ||||||
|     def authenticate_credentials(self, key): |     def authenticate_credentials(self, key): | ||||||
|  |         model = self.get_model() | ||||||
|         try: |         try: | ||||||
|             token = self.model.objects.select_related('user').get(key=key) |             token = model.objects.select_related('user').get(key=key) | ||||||
|         except self.model.DoesNotExist: |         except model.DoesNotExist: | ||||||
|             raise exceptions.AuthenticationFailed(_('Invalid token.')) |             raise exceptions.AuthenticationFailed(_('Invalid token.')) | ||||||
| 
 | 
 | ||||||
|         if not token.user.is_active: |         if not token.user.is_active: | ||||||
|  |  | ||||||
|  | @ -21,14 +21,6 @@ class Token(models.Model): | ||||||
|     user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token') |     user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token') | ||||||
|     created = models.DateTimeField(auto_now_add=True) |     created = models.DateTimeField(auto_now_add=True) | ||||||
| 
 | 
 | ||||||
|     class Meta: |  | ||||||
|         # Work around for a bug in Django: |  | ||||||
|         # https://code.djangoproject.com/ticket/19422 |  | ||||||
|         # |  | ||||||
|         # Also see corresponding ticket: |  | ||||||
|         # https://github.com/tomchristie/django-rest-framework/issues/705 |  | ||||||
|         abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS |  | ||||||
| 
 |  | ||||||
|     def save(self, *args, **kwargs): |     def save(self, *args, **kwargs): | ||||||
|         if not self.key: |         if not self.key: | ||||||
|             self.key = self.generate_key() |             self.key = self.generate_key() | ||||||
|  |  | ||||||
|  | @ -6,6 +6,7 @@ import base64 | ||||||
| 
 | 
 | ||||||
| from django.conf.urls import include, url | from django.conf.urls import include, url | ||||||
| from django.contrib.auth.models import User | from django.contrib.auth.models import User | ||||||
|  | from django.db import models | ||||||
| from django.http import HttpResponse | from django.http import HttpResponse | ||||||
| from django.test import TestCase | from django.test import TestCase | ||||||
| from django.utils import six | from django.utils import six | ||||||
|  | @ -25,6 +26,15 @@ from rest_framework.views import APIView | ||||||
| factory = APIRequestFactory() | factory = APIRequestFactory() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class CustomToken(models.Model): | ||||||
|  |     key = models.CharField(max_length=40, primary_key=True) | ||||||
|  |     user = models.OneToOneField(User) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CustomTokenAuthentication(TokenAuthentication): | ||||||
|  |     model = CustomToken | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class MockView(APIView): | class MockView(APIView): | ||||||
|     permission_classes = (permissions.IsAuthenticated,) |     permission_classes = (permissions.IsAuthenticated,) | ||||||
| 
 | 
 | ||||||
|  | @ -42,6 +52,7 @@ urlpatterns = [ | ||||||
|     url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), |     url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), | ||||||
|     url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), |     url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), | ||||||
|     url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), |     url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), | ||||||
|  |     url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])), | ||||||
|     url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), |     url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), | ||||||
|     url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), |     url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), | ||||||
| ] | ] | ||||||
|  | @ -142,9 +153,11 @@ class SessionAuthTests(TestCase): | ||||||
|         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) |         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TokenAuthTests(TestCase): | class BaseTokenAuthTests(object): | ||||||
|     """Token authentication""" |     """Token authentication""" | ||||||
|     urls = 'tests.test_authentication' |     urls = 'tests.test_authentication' | ||||||
|  |     model = None | ||||||
|  |     path = None | ||||||
| 
 | 
 | ||||||
|     def setUp(self): |     def setUp(self): | ||||||
|         self.csrf_client = APIClient(enforce_csrf_checks=True) |         self.csrf_client = APIClient(enforce_csrf_checks=True) | ||||||
|  | @ -154,24 +167,30 @@ class TokenAuthTests(TestCase): | ||||||
|         self.user = User.objects.create_user(self.username, self.email, self.password) |         self.user = User.objects.create_user(self.username, self.email, self.password) | ||||||
| 
 | 
 | ||||||
|         self.key = 'abcd1234' |         self.key = 'abcd1234' | ||||||
|         self.token = Token.objects.create(key=self.key, user=self.user) |         self.token = self.model.objects.create(key=self.key, user=self.user) | ||||||
| 
 | 
 | ||||||
|     def test_post_form_passing_token_auth(self): |     def test_post_form_passing_token_auth(self): | ||||||
|         """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" |         """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" | ||||||
|         auth = 'Token ' + self.key |         auth = 'Token ' + self.key | ||||||
|         response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) |         response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) | ||||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) |         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||||
| 
 | 
 | ||||||
|  |     def test_fail_post_form_passing_nonexistent_token_auth(self): | ||||||
|  |         # use a nonexistent token key | ||||||
|  |         auth = 'Token wxyz6789' | ||||||
|  |         response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) | ||||||
|  |         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) | ||||||
|  | 
 | ||||||
|     def test_fail_post_form_passing_invalid_token_auth(self): |     def test_fail_post_form_passing_invalid_token_auth(self): | ||||||
|         # add an 'invalid' unicode character |         # add an 'invalid' unicode character | ||||||
|         auth = 'Token ' + self.key + "¸" |         auth = 'Token ' + self.key + "¸" | ||||||
|         response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) |         response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth) | ||||||
|         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) |         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) | ||||||
| 
 | 
 | ||||||
|     def test_post_json_passing_token_auth(self): |     def test_post_json_passing_token_auth(self): | ||||||
|         """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" |         """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" | ||||||
|         auth = "Token " + self.key |         auth = "Token " + self.key | ||||||
|         response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) |         response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) | ||||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) |         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||||
| 
 | 
 | ||||||
|     def test_post_json_makes_one_db_query(self): |     def test_post_json_makes_one_db_query(self): | ||||||
|  | @ -179,29 +198,34 @@ class TokenAuthTests(TestCase): | ||||||
|         auth = "Token " + self.key |         auth = "Token " + self.key | ||||||
| 
 | 
 | ||||||
|         def func_to_test(): |         def func_to_test(): | ||||||
|             return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) |             return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) | ||||||
| 
 | 
 | ||||||
|         self.assertNumQueries(1, func_to_test) |         self.assertNumQueries(1, func_to_test) | ||||||
| 
 | 
 | ||||||
|     def test_post_form_failing_token_auth(self): |     def test_post_form_failing_token_auth(self): | ||||||
|         """Ensure POSTing form over token auth without correct credentials fails""" |         """Ensure POSTing form over token auth without correct credentials fails""" | ||||||
|         response = self.csrf_client.post('/token/', {'example': 'example'}) |         response = self.csrf_client.post(self.path, {'example': 'example'}) | ||||||
|         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) |         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) | ||||||
| 
 | 
 | ||||||
|     def test_post_json_failing_token_auth(self): |     def test_post_json_failing_token_auth(self): | ||||||
|         """Ensure POSTing json over token auth without correct credentials fails""" |         """Ensure POSTing json over token auth without correct credentials fails""" | ||||||
|         response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') |         response = self.csrf_client.post(self.path, {'example': 'example'}, format='json') | ||||||
|         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) |         self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) | ||||||
| 
 | 
 | ||||||
|  | 
 | ||||||
|  | class TokenAuthTests(BaseTokenAuthTests, TestCase): | ||||||
|  |     model = Token | ||||||
|  |     path = '/token/' | ||||||
|  | 
 | ||||||
|     def test_token_has_auto_assigned_key_if_none_provided(self): |     def test_token_has_auto_assigned_key_if_none_provided(self): | ||||||
|         """Ensure creating a token with no key will auto-assign a key""" |         """Ensure creating a token with no key will auto-assign a key""" | ||||||
|         self.token.delete() |         self.token.delete() | ||||||
|         token = Token.objects.create(user=self.user) |         token = self.model.objects.create(user=self.user) | ||||||
|         self.assertTrue(bool(token.key)) |         self.assertTrue(bool(token.key)) | ||||||
| 
 | 
 | ||||||
|     def test_generate_key_returns_string(self): |     def test_generate_key_returns_string(self): | ||||||
|         """Ensure generate_key returns a string""" |         """Ensure generate_key returns a string""" | ||||||
|         token = Token() |         token = self.model() | ||||||
|         key = token.generate_key() |         key = token.generate_key() | ||||||
|         self.assertTrue(isinstance(key, six.string_types)) |         self.assertTrue(isinstance(key, six.string_types)) | ||||||
| 
 | 
 | ||||||
|  | @ -236,6 +260,11 @@ class TokenAuthTests(TestCase): | ||||||
|         self.assertEqual(response.data['token'], self.key) |         self.assertEqual(response.data['token'], self.key) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class CustomTokenAuthTests(BaseTokenAuthTests, TestCase): | ||||||
|  |     model = CustomToken | ||||||
|  |     path = '/customtoken/' | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| class IncorrectCredentialsTests(TestCase): | class IncorrectCredentialsTests(TestCase): | ||||||
|     def test_incorrect_credentials(self): |     def test_incorrect_credentials(self): | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user