From 232b3ed4d4fdb613b5058afbf9dbaa962092d734 Mon Sep 17 00:00:00 2001 From: Bahattin Cinic Date: Tue, 8 Jul 2014 14:42:22 +0300 Subject: [PATCH] added expired feature for token authentication --- rest_framework/authentication.py | 5 +++++ rest_framework/authtoken/models.py | 10 +++++++++ rest_framework/authtoken/views.py | 18 +++++++++++++--- rest_framework/settings.py | 8 ++++++- rest_framework/tests/test_authentication.py | 23 +++++++++++++++++++++ 5 files changed, 60 insertions(+), 4 deletions(-) diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index da9ca510e..c3bc84c6d 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -7,6 +7,7 @@ import base64 from django.contrib.auth import authenticate from django.core.exceptions import ImproperlyConfigured from django.conf import settings +from rest_framework.settings import api_settings from rest_framework import exceptions, HTTP_HEADER_ENCODING from rest_framework.compat import CsrfViewMiddleware from rest_framework.compat import oauth, oauth_provider, oauth_provider_store @@ -174,6 +175,10 @@ class TokenAuthentication(BaseAuthentication): if not token.user.is_active: raise exceptions.AuthenticationFailed('User inactive or deleted') + token_settings = api_settings.DEFAULT_TOKEN_EXPIRE + if token_settings['is_expired'] and token.check_for_expiration(): + raise exceptions.AuthenticationFailed('Token has expired') + return (token.user, token) def authenticate_header(self, request): diff --git a/rest_framework/authtoken/models.py b/rest_framework/authtoken/models.py index 167fa5314..8435b17d9 100644 --- a/rest_framework/authtoken/models.py +++ b/rest_framework/authtoken/models.py @@ -1,8 +1,10 @@ import binascii import os +import datetime from hashlib import sha1 from django.conf import settings from django.db import models +from rest_framework.settings import api_settings # Prior to Django 1.5, the AUTH_USER_MODEL setting does not exist. @@ -36,5 +38,13 @@ class Token(models.Model): def generate_key(self): return binascii.hexlify(os.urandom(20)).decode() + def check_for_expiration(self): + token_settings = api_settings.DEFAULT_TOKEN_EXPIRE + if token_settings['is_expired']: + now = datetime.datetime.now() + difference = datetime.timedelta(days=token_settings['expiration_time']) + return self.created < (now - difference) + return False + def __unicode__(self): return self.key diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 7c03cb766..33652879f 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -1,3 +1,4 @@ +import datetime from rest_framework.views import APIView from rest_framework import status from rest_framework import parsers @@ -5,12 +6,14 @@ from rest_framework import renderers from rest_framework.response import Response from rest_framework.authtoken.models import Token from rest_framework.authtoken.serializers import AuthTokenSerializer +from rest_framework.settings import api_settings class ObtainAuthToken(APIView): throttle_classes = () permission_classes = () - parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) + parser_classes = (parsers.FormParser, parsers.MultiPartParser, + parsers.JSONParser,) renderer_classes = (renderers.JSONRenderer,) serializer_class = AuthTokenSerializer model = Token @@ -18,8 +21,17 @@ class ObtainAuthToken(APIView): def post(self, request): serializer = self.serializer_class(data=request.DATA) if serializer.is_valid(): - token, created = Token.objects.get_or_create(user=serializer.object['user']) - return Response({'token': token.key}) + token, created = Token.objects.get_or_create( + user=serializer.object['user']) + + token_settings = api_settings.DEFAULT_TOKEN_EXPIRE + key = token.key + if not created and token_settings['is_expired']: + # update the created time of the token to keep it valid + key = token.generate_key() if token.check_for_expiration() else key + Token.objects.filter(key=token.key).update( + key=key, created=datetime.datetime.now()) + return Response({'token': key}) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 38753c968..b095d8e03 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -50,6 +50,11 @@ DEFAULTS = { ), 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', + 'DEFAULT_TOKEN_EXPIRE': { + 'is_expired': False, + # in days + 'expiration_time': 30 + }, # Genric view behavior 'DEFAULT_MODEL_SERIALIZER_CLASS': @@ -139,7 +144,8 @@ IMPORT_STRINGS = ( 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', 'VIEW_NAME_FUNCTION', - 'VIEW_DESCRIPTION_FUNCTION' + 'VIEW_DESCRIPTION_FUNCTION', + 'DEFAULT_TOKEN_EXPIRE' ) diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index a1c43d9ce..8d8dd6e21 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -24,6 +24,7 @@ from rest_framework.compat import oauth2_provider, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView +from rest_framework.settings import api_settings import base64 import time import datetime @@ -231,6 +232,28 @@ class TokenAuthTests(TestCase): self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.data['token'], self.key) + def test_token_expired(self): + """ Ensure token login view using expired token """ + api_settings.DEFAULT_TOKEN_EXPIRE['is_expired'] = True + client = APIClient(enforce_csrf_checks=True) + self.token.created = self.token.created - datetime.timedelta(days=40) + self.token.save() + response = client.post('/token/', {'example': 'example'}, + HTTP_AUTHORIZATION='Token %s' % self.token.key, + format='json') + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) + + def test_token_expire_after_renewal(self): + """ Ensure token renewes on next login after expiration """ + api_settings.DEFAULT_TOKEN_EXPIRE['is_expired'] = True + self.token.created = self.token.created - datetime.timedelta(days=40) + self.token.save() + client = APIClient(enforce_csrf_checks=True) + response = client.post('/auth-token/', {'username': self.username, + 'password': self.password}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertNotEqual(response.data['token'], self.key) + class IncorrectCredentialsTests(TestCase): def test_incorrect_credentials(self):