From 3a083fe844ba8ef75a0fef3c3dba3f6cc801c6c8 Mon Sep 17 00:00:00 2001 From: vikas soni Date: Sun, 16 Feb 2020 00:35:19 +0530 Subject: [PATCH] [fet] jwt authentication 1. Api to obtain jwt: rest_framework.authtoken.views.ObtainJSONWebToken 2. jwt based token authentication: rest_framework.authentication.JWTAuthentication --- rest_framework/authentication.py | 72 ++++++++++++++++++++- rest_framework/authtoken/handlers.py | 45 +++++++++++++ rest_framework/authtoken/views.py | 44 +++++++++++++ rest_framework/settings.py | 3 + tests/authentication/test_authentication.py | 39 ++++++++++- 5 files changed, 201 insertions(+), 2 deletions(-) create mode 100644 rest_framework/authtoken/handlers.py diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 1e30728d3..40f5f8277 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -7,7 +7,8 @@ import binascii from django.contrib.auth import authenticate, get_user_model from django.middleware.csrf import CsrfViewMiddleware from django.utils.translation import gettext_lazy as _ - +from django.conf import settings as django_settings +from rest_framework.authtoken.handlers import JWTHandler from rest_framework import HTTP_HEADER_ENCODING, exceptions @@ -223,3 +224,72 @@ class RemoteUserAuthentication(BaseAuthentication): user = authenticate(remote_user=request.META.get(self.header)) if user and user.is_active: return (user, None) + + +class JWTAuthentication(BaseAuthentication): + """ + Json Web Token authentication. + + Clients should authenticate by passing the jwt key in the "Authorization" + HTTP header, prepended with the string "Bearer ". For example: + + Authorization: Bearer + """ + + keyword = 'Bearer' + _user_model = None + + """ + A custom token model may be used, but must have the following properties. + + * key -- The string identifying the token + * user -- The user to which the token belongs + """ + + def authenticate(self, request): + auth = get_authorization_header(request).split() + + if not auth or auth[0].lower() != self.keyword.lower().encode(): + return None + + if len(auth) == 1: + msg = _('Invalid token header. No credentials provided.') + raise exceptions.AuthenticationFailed(msg) + elif len(auth) > 2: + msg = _('Invalid token header. Token string should not contain spaces.') + raise exceptions.AuthenticationFailed(msg) + + try: + token = auth[1].decode() + except UnicodeError: + msg = _('Invalid token header. Token string should not contain invalid characters.') + raise exceptions.AuthenticationFailed(msg) + + return self.authenticate_credentials(token) + + def authenticate_credentials(self, key): + payload = JWTHandler.decode(token=key) + user = self.get_user(payload) + return (user, key) + + def authenticate_header(self, request): + return self.keyword + + def get_user(self, payload): + try: + user = self.user_model.objects.get(pk=payload.get("user")) + except self.user_model.DoesNotExist: + raise exceptions.AuthenticationFailed("User Doesn't Exist.") + + if not user.is_active: + raise exceptions.AuthenticationFailed(_('User inactive or deleted.')) + + return user + + @property + def user_model(self): + if self._user_model is None: + from django.apps import apps + self._user_model = apps.get_model(*django_settings.AUTH_USER_MODEL.rsplit(".", 1)) + + return self._user_model \ No newline at end of file diff --git a/rest_framework/authtoken/handlers.py b/rest_framework/authtoken/handlers.py new file mode 100644 index 000000000..b3c67c9eb --- /dev/null +++ b/rest_framework/authtoken/handlers.py @@ -0,0 +1,45 @@ +from typing import Dict +from calendar import timegm +from jwt import encode +from jwt import decode +from django.conf import settings as django_settings +from django.utils.translation import gettext_lazy as _ +from django.utils.timezone import now +from rest_framework.settings import api_settings +from rest_framework import exceptions + + + +class JWTHandler: + """ + creates api specific tokens + """ + @classmethod + def encode(cls, payload: Dict, secret=None, algorithm=None, headers=None, json_encoder=None): + cls.set_expiration(payload) + + _key = secret or django_settings.SECRET_KEY + _algorithm = algorithm or api_settings.DEFAULT_JWT_ALGORITHM + return encode(payload=payload, key=_key, algorithm=_algorithm, headers=headers, json_encoder=json_encoder) + + @classmethod + def decode(cls, token, secret=None, algorithm=None): + _key = secret or django_settings.SECRET_KEY + _algorithm = algorithm or api_settings.DEFAULT_JWT_ALGORITHM + + try: + _payload = decode(jwt=token, key=_key, algorithms=[_algorithm]) + except exceptions.APIException as e: + raise exceptions.APIException(_(str(e))) + + return _payload + + @classmethod + def set_expiration(self, payload, _name="exp", _from=None, _duration=None): + """ + Updates the expiration time of a token. + """ + _from = _from or now() + _duration = _duration or api_settings.DEFAULT_JWT_DURATION + _expiration_time = _from + _duration + payload[_name] = timegm(_expiration_time.utctimetuple()) diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index a8c751d51..72cff3114 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -1,5 +1,6 @@ from rest_framework import parsers, renderers from rest_framework.authtoken.models import Token +from rest_framework.authtoken.handlers import JWTHandler from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.compat import coreapi, coreschema from rest_framework.response import Response @@ -46,5 +47,48 @@ class ObtainAuthToken(APIView): token, created = Token.objects.get_or_create(user=user) return Response({'token': token.key}) +class ObtainJSONWebToken(APIView): + throttle_classes = () + permission_classes = () + parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) + renderer_classes = (renderers.JSONRenderer,) + serializer_class = AuthTokenSerializer + if coreapi is not None and coreschema is not None: + schema = ManualSchema( + fields=[ + coreapi.Field( + name="username", + required=True, + location='form', + schema=coreschema.String( + title="Username", + description="Valid username for authentication", + ), + ), + coreapi.Field( + name="password", + required=True, + location='form', + schema=coreschema.String( + title="Password", + description="Valid password for authentication", + ), + ), + ], + encoding="application/json", + ) + + def post(self, request, *args, **kwargs): + serializer = self.serializer_class(data=request.data, + context={'request': request}) + serializer.is_valid(raise_exception=True) + user = serializer.validated_data['user'] + key = JWTHandler.encode({ + "user": user.pk + }) + return Response({'token': key}) + + +obtain_json_web_token = ObtainJSONWebToken.as_view() obtain_auth_token = ObtainAuthToken.as_view() diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 9eb4c5653..3c69e9a4e 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -18,6 +18,7 @@ This module provides the `api_setting` object, that is used to access REST framework settings, checking for user settings first, then falling back to the defaults. """ +from datetime import timedelta from django.conf import settings from django.test.signals import setting_changed from django.utils.module_loading import import_string @@ -124,6 +125,8 @@ DEFAULTS = { 'retrieve': 'read', 'destroy': 'delete' }, + "DEFAULT_JWT_ALGORITHM": "HS256", + "DEFAULT_JWT_DURATION": timedelta(minutes=5), } diff --git a/tests/authentication/test_authentication.py b/tests/authentication/test_authentication.py index 37e265e17..8cc15381b 100644 --- a/tests/authentication/test_authentication.py +++ b/tests/authentication/test_authentication.py @@ -15,7 +15,7 @@ from rest_framework.authentication import ( SessionAuthentication, TokenAuthentication ) from rest_framework.authtoken.models import Token -from rest_framework.authtoken.views import obtain_auth_token +from rest_framework.authtoken.views import obtain_auth_token, obtain_json_web_token from rest_framework.response import Response from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView @@ -74,6 +74,7 @@ urlpatterns = [ ) ), url(r'^auth-token/$', obtain_auth_token), + url(r'^auth-jwt/$', obtain_json_web_token), url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')), ] @@ -420,6 +421,42 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase): assert response.data['token'] == self.key +@override_settings(ROOT_URLCONF=__name__) +class JSONWebTokenAuthTests(BaseTokenAuthTests, TestCase): + path = '/token/' + + def test_token_login_json(self): + """Ensure token login view using JSON POST works.""" + client = APIClient(enforce_csrf_checks=True) + response = client.post( + '/auth-jwt/', + {'username': self.username, 'password': self.password}, + format='json' + ) + assert response.status_code == status.HTTP_200_OK + assert "token" in response.data + + def test_token_login_json_bad_creds(self): + """ + Ensure token login view using JSON POST fails if + bad credentials are used + """ + client = APIClient(enforce_csrf_checks=True) + response = client.post( + '/auth-jwt/', + {'username': self.username, 'password': "badpass"}, + format='json' + ) + assert response.status_code == 400 + + def test_token_login_json_missing_fields(self): + """Ensure token login view using JSON POST fails if missing fields.""" + client = APIClient(enforce_csrf_checks=True) + response = client.post('/auth-jwt/', + {'username': self.username}, format='json') + assert response.status_code == 400 + + @override_settings(ROOT_URLCONF=__name__) class CustomTokenAuthTests(BaseTokenAuthTests, TestCase): model = CustomToken