added expired feature for token authentication

This commit is contained in:
Bahattin Cinic 2014-07-08 14:42:22 +03:00
parent 91eabd54bb
commit 232b3ed4d4
5 changed files with 60 additions and 4 deletions

View File

@ -7,6 +7,7 @@ import base64
from django.contrib.auth import authenticate
from django.core.exceptions import ImproperlyConfigured
from django.conf import settings
from rest_framework.settings import api_settings
from rest_framework import exceptions, HTTP_HEADER_ENCODING
from rest_framework.compat import CsrfViewMiddleware
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
@ -174,6 +175,10 @@ class TokenAuthentication(BaseAuthentication):
if not token.user.is_active:
raise exceptions.AuthenticationFailed('User inactive or deleted')
token_settings = api_settings.DEFAULT_TOKEN_EXPIRE
if token_settings['is_expired'] and token.check_for_expiration():
raise exceptions.AuthenticationFailed('Token has expired')
return (token.user, token)
def authenticate_header(self, request):

View File

@ -1,8 +1,10 @@
import binascii
import os
import datetime
from hashlib import sha1
from django.conf import settings
from django.db import models
from rest_framework.settings import api_settings
# Prior to Django 1.5, the AUTH_USER_MODEL setting does not exist.
@ -36,5 +38,13 @@ class Token(models.Model):
def generate_key(self):
return binascii.hexlify(os.urandom(20)).decode()
def check_for_expiration(self):
token_settings = api_settings.DEFAULT_TOKEN_EXPIRE
if token_settings['is_expired']:
now = datetime.datetime.now()
difference = datetime.timedelta(days=token_settings['expiration_time'])
return self.created < (now - difference)
return False
def __unicode__(self):
return self.key

View File

@ -1,3 +1,4 @@
import datetime
from rest_framework.views import APIView
from rest_framework import status
from rest_framework import parsers
@ -5,12 +6,14 @@ from rest_framework import renderers
from rest_framework.response import Response
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.settings import api_settings
class ObtainAuthToken(APIView):
throttle_classes = ()
permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
parser_classes = (parsers.FormParser, parsers.MultiPartParser,
parsers.JSONParser,)
renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer
model = Token
@ -18,8 +21,17 @@ class ObtainAuthToken(APIView):
def post(self, request):
serializer = self.serializer_class(data=request.DATA)
if serializer.is_valid():
token, created = Token.objects.get_or_create(user=serializer.object['user'])
return Response({'token': token.key})
token, created = Token.objects.get_or_create(
user=serializer.object['user'])
token_settings = api_settings.DEFAULT_TOKEN_EXPIRE
key = token.key
if not created and token_settings['is_expired']:
# update the created time of the token to keep it valid
key = token.generate_key() if token.check_for_expiration() else key
Token.objects.filter(key=token.key).update(
key=key, created=datetime.datetime.now())
return Response({'token': key})
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)

View File

@ -50,6 +50,11 @@ DEFAULTS = {
),
'DEFAULT_CONTENT_NEGOTIATION_CLASS':
'rest_framework.negotiation.DefaultContentNegotiation',
'DEFAULT_TOKEN_EXPIRE': {
'is_expired': False,
# in days
'expiration_time': 30
},
# Genric view behavior
'DEFAULT_MODEL_SERIALIZER_CLASS':
@ -139,7 +144,8 @@ IMPORT_STRINGS = (
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
'VIEW_NAME_FUNCTION',
'VIEW_DESCRIPTION_FUNCTION'
'VIEW_DESCRIPTION_FUNCTION',
'DEFAULT_TOKEN_EXPIRE'
)

View File

@ -24,6 +24,7 @@ from rest_framework.compat import oauth2_provider, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider
from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView
from rest_framework.settings import api_settings
import base64
import time
import datetime
@ -231,6 +232,28 @@ class TokenAuthTests(TestCase):
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data['token'], self.key)
def test_token_expired(self):
""" Ensure token login view using expired token """
api_settings.DEFAULT_TOKEN_EXPIRE['is_expired'] = True
client = APIClient(enforce_csrf_checks=True)
self.token.created = self.token.created - datetime.timedelta(days=40)
self.token.save()
response = client.post('/token/', {'example': 'example'},
HTTP_AUTHORIZATION='Token %s' % self.token.key,
format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_token_expire_after_renewal(self):
""" Ensure token renewes on next login after expiration """
api_settings.DEFAULT_TOKEN_EXPIRE['is_expired'] = True
self.token.created = self.token.created - datetime.timedelta(days=40)
self.token.save()
client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', {'username': self.username,
'password': self.password})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertNotEqual(response.data['token'], self.key)
class IncorrectCredentialsTests(TestCase):
def test_incorrect_credentials(self):