diff --git a/docs/configuration.rst b/docs/configuration.rst index 282f326..336abb5 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -29,10 +29,12 @@ Configuration ... } +- **REST_AUTH_TOKEN_MODEL** - model class for tokens, default value ``rest_framework.authtoken.models`` + +- **REST_AUTH_TOKEN_CREATOR** - callable to create tokens, default value ``rest_auth.utils.default_create_token``. - **REST_SESSION_LOGIN** - Enable session login in Login API view (default: True) - - **OLD_PASSWORD_FIELD_ENABLED** - set it to True if you want to have old password verification on password change enpoint (default: False) - **LOGOUT_ON_PASSWORD_CHANGE** - set to False if you want to keep the current user logged in after a password change diff --git a/rest_auth/app_settings.py b/rest_auth/app_settings.py index e0340b7..b77d1d2 100644 --- a/rest_auth/app_settings.py +++ b/rest_auth/app_settings.py @@ -7,8 +7,10 @@ from rest_auth.serializers import ( PasswordResetSerializer as DefaultPasswordResetSerializer, PasswordResetConfirmSerializer as DefaultPasswordResetConfirmSerializer, PasswordChangeSerializer as DefaultPasswordChangeSerializer) -from .utils import import_callable +from .utils import import_callable, default_create_token +create_token = import_callable( + getattr(settings, 'REST_AUTH_TOKEN_CREATOR', default_create_token)) serializers = getattr(settings, 'REST_AUTH_SERIALIZERS', {}) diff --git a/rest_auth/models.py b/rest_auth/models.py index e703865..a132f9c 100644 --- a/rest_auth/models.py +++ b/rest_auth/models.py @@ -1,3 +1,10 @@ -# from django.db import models +from django.conf import settings + +from rest_framework.authtoken.models import Token as DefaultTokenModel + +from .utils import import_callable # Register your models here. + +TokenModel = import_callable( + getattr(settings, 'REST_AUTH_TOKEN_MODEL', DefaultTokenModel)) diff --git a/rest_auth/registration/views.py b/rest_auth/registration/views.py index e700706..57a1327 100644 --- a/rest_auth/registration/views.py +++ b/rest_auth/registration/views.py @@ -3,7 +3,6 @@ from rest_framework.views import APIView from rest_framework.response import Response from rest_framework.permissions import AllowAny from rest_framework import status -from rest_framework.authtoken.models import Token from allauth.account.views import SignupView, ConfirmEmailView from allauth.account.utils import complete_signup @@ -12,6 +11,7 @@ from allauth.account import app_settings from rest_auth.app_settings import TokenSerializer from rest_auth.registration.serializers import SocialLoginSerializer from rest_auth.views import LoginView +from rest_auth.models import TokenModel class RegisterView(APIView, SignupView): @@ -27,7 +27,7 @@ class RegisterView(APIView, SignupView): permission_classes = (AllowAny,) allowed_methods = ('POST', 'OPTIONS', 'HEAD') - token_model = Token + token_model = TokenModel serializer_class = TokenSerializer def get(self, *args, **kwargs): diff --git a/rest_auth/serializers.py b/rest_auth/serializers.py index a2d1a82..2dd92b8 100644 --- a/rest_auth/serializers.py +++ b/rest_auth/serializers.py @@ -6,8 +6,9 @@ from django.utils.http import urlsafe_base64_decode as uid_decoder from django.utils.translation import ugettext_lazy as _ from django.utils.encoding import force_text +from .models import TokenModel + from rest_framework import serializers, exceptions -from rest_framework.authtoken.models import Token from rest_framework.exceptions import ValidationError # Get the UserModel @@ -84,7 +85,7 @@ class TokenSerializer(serializers.ModelSerializer): """ class Meta: - model = Token + model = TokenModel fields = ('key',) diff --git a/rest_auth/utils.py b/rest_auth/utils.py index a32da60..e5bbf7c 100644 --- a/rest_auth/utils.py +++ b/rest_auth/utils.py @@ -9,3 +9,9 @@ def import_callable(path_or_callable): assert isinstance(path_or_callable, string_types) package, attr = path_or_callable.rsplit('.', 1) return getattr(import_module(package), attr) + + +def default_create_token(token_model, serializer): + user = serializer.validated_data['user'] + token, _ = token_model.objects.get_or_create(user=user) + return token diff --git a/rest_auth/views.py b/rest_auth/views.py index 3af1557..6f3c413 100644 --- a/rest_auth/views.py +++ b/rest_auth/views.py @@ -7,14 +7,14 @@ from rest_framework.views import APIView from rest_framework.response import Response from rest_framework.generics import GenericAPIView from rest_framework.permissions import IsAuthenticated, AllowAny -from rest_framework.authtoken.models import Token from rest_framework.generics import RetrieveUpdateAPIView from .app_settings import ( TokenSerializer, UserDetailsSerializer, LoginSerializer, PasswordResetSerializer, PasswordResetConfirmSerializer, - PasswordChangeSerializer + PasswordChangeSerializer, create_token ) +from .models import TokenModel class LoginView(GenericAPIView): @@ -30,13 +30,12 @@ class LoginView(GenericAPIView): """ permission_classes = (AllowAny,) serializer_class = LoginSerializer - token_model = Token + token_model = TokenModel response_serializer = TokenSerializer def login(self): self.user = self.serializer.validated_data['user'] - self.token, created = self.token_model.objects.get_or_create( - user=self.user) + self.token = create_token(self.token_model, self.serializer) if getattr(settings, 'REST_SESSION_LOGIN', True): login(self.request, self.user)