test custom token model

This commit is contained in:
S. Andrew Sheppard 2016-01-05 09:58:16 -06:00
parent 1712c00001
commit 2b8c036b48

View File

@ -9,6 +9,7 @@ from django.contrib.auth.models import User
from django.http import HttpResponse from django.http import HttpResponse
from django.test import TestCase from django.test import TestCase
from django.utils import six from django.utils import six
from django.db import models
from rest_framework import ( from rest_framework import (
HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status
@ -25,6 +26,15 @@ from rest_framework.views import APIView
factory = APIRequestFactory() factory = APIRequestFactory()
class CustomToken(models.Model):
key = models.CharField(max_length=40, primary_key=True)
user = models.OneToOneField(User)
class CustomTokenAuthentication(TokenAuthentication):
model = CustomToken
class MockView(APIView): class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,) permission_classes = (permissions.IsAuthenticated,)
@ -42,6 +52,7 @@ urlpatterns = [
url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])),
url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
] ]
@ -142,9 +153,11 @@ class SessionAuthTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
class TokenAuthTests(TestCase): class BaseTokenAuthTests(object):
"""Token authentication""" """Token authentication"""
urls = 'tests.test_authentication' urls = 'tests.test_authentication'
model = None
path = None
def setUp(self): def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
@ -154,30 +167,30 @@ class TokenAuthTests(TestCase):
self.user = User.objects.create_user(self.username, self.email, self.password) self.user = User.objects.create_user(self.username, self.email, self.password)
self.key = 'abcd1234' self.key = 'abcd1234'
self.token = Token.objects.create(key=self.key, user=self.user) self.token = self.model.objects.create(key=self.key, user=self.user)
def test_post_form_passing_token_auth(self): def test_post_form_passing_token_auth(self):
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
auth = 'Token ' + self.key auth = 'Token ' + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_fail_post_form_passing_nonexistent_token_auth(self): def test_fail_post_form_passing_nonexistent_token_auth(self):
# use a nonexistent token key # use a nonexistent token key
auth = 'Token wxyz6789' auth = 'Token wxyz6789'
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_fail_post_form_passing_invalid_token_auth(self): def test_fail_post_form_passing_invalid_token_auth(self):
# add an 'invalid' unicode character # add an 'invalid' unicode character
auth = 'Token ' + self.key + "¸" auth = 'Token ' + self.key + "¸"
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_passing_token_auth(self): def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_json_makes_one_db_query(self): def test_post_json_makes_one_db_query(self):
@ -185,29 +198,34 @@ class TokenAuthTests(TestCase):
auth = "Token " + self.key auth = "Token " + self.key
def func_to_test(): def func_to_test():
return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertNumQueries(1, func_to_test) self.assertNumQueries(1, func_to_test)
def test_post_form_failing_token_auth(self): def test_post_form_failing_token_auth(self):
"""Ensure POSTing form over token auth without correct credentials fails""" """Ensure POSTing form over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'}) response = self.csrf_client.post(self.path, {'example': 'example'})
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails""" """Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') response = self.csrf_client.post(self.path, {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
class TokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
path = '/token/'
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key""" """Ensure creating a token with no key will auto-assign a key"""
self.token.delete() self.token.delete()
token = Token.objects.create(user=self.user) token = self.model.objects.create(user=self.user)
self.assertTrue(bool(token.key)) self.assertTrue(bool(token.key))
def test_generate_key_returns_string(self): def test_generate_key_returns_string(self):
"""Ensure generate_key returns a string""" """Ensure generate_key returns a string"""
token = Token() token = self.model()
key = token.generate_key() key = token.generate_key()
self.assertTrue(isinstance(key, six.string_types)) self.assertTrue(isinstance(key, six.string_types))
@ -242,6 +260,11 @@ class TokenAuthTests(TestCase):
self.assertEqual(response.data['token'], self.key) self.assertEqual(response.data['token'], self.key)
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
model = CustomToken
path = '/customtoken/'
class IncorrectCredentialsTests(TestCase): class IncorrectCredentialsTests(TestCase):
def test_incorrect_credentials(self): def test_incorrect_credentials(self):
""" """