mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-28 00:49:49 +03:00
[fet] jwt authentication
1. Api to obtain jwt: rest_framework.authtoken.views.ObtainJSONWebToken 2. jwt based token authentication: rest_framework.authentication.JWTAuthentication
This commit is contained in:
parent
f81ca78642
commit
3a083fe844
|
@ -7,7 +7,8 @@ import binascii
|
|||
from django.contrib.auth import authenticate, get_user_model
|
||||
from django.middleware.csrf import CsrfViewMiddleware
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from django.conf import settings as django_settings
|
||||
from rest_framework.authtoken.handlers import JWTHandler
|
||||
from rest_framework import HTTP_HEADER_ENCODING, exceptions
|
||||
|
||||
|
||||
|
@ -223,3 +224,72 @@ class RemoteUserAuthentication(BaseAuthentication):
|
|||
user = authenticate(remote_user=request.META.get(self.header))
|
||||
if user and user.is_active:
|
||||
return (user, None)
|
||||
|
||||
|
||||
class JWTAuthentication(BaseAuthentication):
|
||||
"""
|
||||
Json Web Token authentication.
|
||||
|
||||
Clients should authenticate by passing the jwt key in the "Authorization"
|
||||
HTTP header, prepended with the string "Bearer ". For example:
|
||||
|
||||
Authorization: Bearer <json-web-token>
|
||||
"""
|
||||
|
||||
keyword = 'Bearer'
|
||||
_user_model = None
|
||||
|
||||
"""
|
||||
A custom token model may be used, but must have the following properties.
|
||||
|
||||
* key -- The string identifying the token
|
||||
* user -- The user to which the token belongs
|
||||
"""
|
||||
|
||||
def authenticate(self, request):
|
||||
auth = get_authorization_header(request).split()
|
||||
|
||||
if not auth or auth[0].lower() != self.keyword.lower().encode():
|
||||
return None
|
||||
|
||||
if len(auth) == 1:
|
||||
msg = _('Invalid token header. No credentials provided.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
elif len(auth) > 2:
|
||||
msg = _('Invalid token header. Token string should not contain spaces.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
try:
|
||||
token = auth[1].decode()
|
||||
except UnicodeError:
|
||||
msg = _('Invalid token header. Token string should not contain invalid characters.')
|
||||
raise exceptions.AuthenticationFailed(msg)
|
||||
|
||||
return self.authenticate_credentials(token)
|
||||
|
||||
def authenticate_credentials(self, key):
|
||||
payload = JWTHandler.decode(token=key)
|
||||
user = self.get_user(payload)
|
||||
return (user, key)
|
||||
|
||||
def authenticate_header(self, request):
|
||||
return self.keyword
|
||||
|
||||
def get_user(self, payload):
|
||||
try:
|
||||
user = self.user_model.objects.get(pk=payload.get("user"))
|
||||
except self.user_model.DoesNotExist:
|
||||
raise exceptions.AuthenticationFailed("User Doesn't Exist.")
|
||||
|
||||
if not user.is_active:
|
||||
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
|
||||
|
||||
return user
|
||||
|
||||
@property
|
||||
def user_model(self):
|
||||
if self._user_model is None:
|
||||
from django.apps import apps
|
||||
self._user_model = apps.get_model(*django_settings.AUTH_USER_MODEL.rsplit(".", 1))
|
||||
|
||||
return self._user_model
|
45
rest_framework/authtoken/handlers.py
Normal file
45
rest_framework/authtoken/handlers.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
from typing import Dict
|
||||
from calendar import timegm
|
||||
from jwt import encode
|
||||
from jwt import decode
|
||||
from django.conf import settings as django_settings
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django.utils.timezone import now
|
||||
from rest_framework.settings import api_settings
|
||||
from rest_framework import exceptions
|
||||
|
||||
|
||||
|
||||
class JWTHandler:
|
||||
"""
|
||||
creates api specific tokens
|
||||
"""
|
||||
@classmethod
|
||||
def encode(cls, payload: Dict, secret=None, algorithm=None, headers=None, json_encoder=None):
|
||||
cls.set_expiration(payload)
|
||||
|
||||
_key = secret or django_settings.SECRET_KEY
|
||||
_algorithm = algorithm or api_settings.DEFAULT_JWT_ALGORITHM
|
||||
return encode(payload=payload, key=_key, algorithm=_algorithm, headers=headers, json_encoder=json_encoder)
|
||||
|
||||
@classmethod
|
||||
def decode(cls, token, secret=None, algorithm=None):
|
||||
_key = secret or django_settings.SECRET_KEY
|
||||
_algorithm = algorithm or api_settings.DEFAULT_JWT_ALGORITHM
|
||||
|
||||
try:
|
||||
_payload = decode(jwt=token, key=_key, algorithms=[_algorithm])
|
||||
except exceptions.APIException as e:
|
||||
raise exceptions.APIException(_(str(e)))
|
||||
|
||||
return _payload
|
||||
|
||||
@classmethod
|
||||
def set_expiration(self, payload, _name="exp", _from=None, _duration=None):
|
||||
"""
|
||||
Updates the expiration time of a token.
|
||||
"""
|
||||
_from = _from or now()
|
||||
_duration = _duration or api_settings.DEFAULT_JWT_DURATION
|
||||
_expiration_time = _from + _duration
|
||||
payload[_name] = timegm(_expiration_time.utctimetuple())
|
|
@ -1,5 +1,6 @@
|
|||
from rest_framework import parsers, renderers
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.authtoken.handlers import JWTHandler
|
||||
from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||||
from rest_framework.compat import coreapi, coreschema
|
||||
from rest_framework.response import Response
|
||||
|
@ -46,5 +47,48 @@ class ObtainAuthToken(APIView):
|
|||
token, created = Token.objects.get_or_create(user=user)
|
||||
return Response({'token': token.key})
|
||||
|
||||
class ObtainJSONWebToken(APIView):
|
||||
throttle_classes = ()
|
||||
permission_classes = ()
|
||||
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
|
||||
renderer_classes = (renderers.JSONRenderer,)
|
||||
serializer_class = AuthTokenSerializer
|
||||
if coreapi is not None and coreschema is not None:
|
||||
schema = ManualSchema(
|
||||
fields=[
|
||||
coreapi.Field(
|
||||
name="username",
|
||||
required=True,
|
||||
location='form',
|
||||
schema=coreschema.String(
|
||||
title="Username",
|
||||
description="Valid username for authentication",
|
||||
),
|
||||
),
|
||||
coreapi.Field(
|
||||
name="password",
|
||||
required=True,
|
||||
location='form',
|
||||
schema=coreschema.String(
|
||||
title="Password",
|
||||
description="Valid password for authentication",
|
||||
),
|
||||
),
|
||||
],
|
||||
encoding="application/json",
|
||||
)
|
||||
|
||||
def post(self, request, *args, **kwargs):
|
||||
serializer = self.serializer_class(data=request.data,
|
||||
context={'request': request})
|
||||
serializer.is_valid(raise_exception=True)
|
||||
user = serializer.validated_data['user']
|
||||
key = JWTHandler.encode({
|
||||
"user": user.pk
|
||||
})
|
||||
return Response({'token': key})
|
||||
|
||||
|
||||
obtain_json_web_token = ObtainJSONWebToken.as_view()
|
||||
|
||||
obtain_auth_token = ObtainAuthToken.as_view()
|
||||
|
|
|
@ -18,6 +18,7 @@ This module provides the `api_setting` object, that is used to access
|
|||
REST framework settings, checking for user settings first, then falling
|
||||
back to the defaults.
|
||||
"""
|
||||
from datetime import timedelta
|
||||
from django.conf import settings
|
||||
from django.test.signals import setting_changed
|
||||
from django.utils.module_loading import import_string
|
||||
|
@ -124,6 +125,8 @@ DEFAULTS = {
|
|||
'retrieve': 'read',
|
||||
'destroy': 'delete'
|
||||
},
|
||||
"DEFAULT_JWT_ALGORITHM": "HS256",
|
||||
"DEFAULT_JWT_DURATION": timedelta(minutes=5),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from rest_framework.authentication import (
|
|||
SessionAuthentication, TokenAuthentication
|
||||
)
|
||||
from rest_framework.authtoken.models import Token
|
||||
from rest_framework.authtoken.views import obtain_auth_token
|
||||
from rest_framework.authtoken.views import obtain_auth_token, obtain_json_web_token
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.test import APIClient, APIRequestFactory
|
||||
from rest_framework.views import APIView
|
||||
|
@ -74,6 +74,7 @@ urlpatterns = [
|
|||
)
|
||||
),
|
||||
url(r'^auth-token/$', obtain_auth_token),
|
||||
url(r'^auth-jwt/$', obtain_json_web_token),
|
||||
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
|
||||
]
|
||||
|
||||
|
@ -420,6 +421,42 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase):
|
|||
assert response.data['token'] == self.key
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF=__name__)
|
||||
class JSONWebTokenAuthTests(BaseTokenAuthTests, TestCase):
|
||||
path = '/token/'
|
||||
|
||||
def test_token_login_json(self):
|
||||
"""Ensure token login view using JSON POST works."""
|
||||
client = APIClient(enforce_csrf_checks=True)
|
||||
response = client.post(
|
||||
'/auth-jwt/',
|
||||
{'username': self.username, 'password': self.password},
|
||||
format='json'
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert "token" in response.data
|
||||
|
||||
def test_token_login_json_bad_creds(self):
|
||||
"""
|
||||
Ensure token login view using JSON POST fails if
|
||||
bad credentials are used
|
||||
"""
|
||||
client = APIClient(enforce_csrf_checks=True)
|
||||
response = client.post(
|
||||
'/auth-jwt/',
|
||||
{'username': self.username, 'password': "badpass"},
|
||||
format='json'
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_token_login_json_missing_fields(self):
|
||||
"""Ensure token login view using JSON POST fails if missing fields."""
|
||||
client = APIClient(enforce_csrf_checks=True)
|
||||
response = client.post('/auth-jwt/',
|
||||
{'username': self.username}, format='json')
|
||||
assert response.status_code == 400
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF=__name__)
|
||||
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
|
||||
model = CustomToken
|
||||
|
|
Loading…
Reference in New Issue
Block a user