diff --git a/docs/configuration.rst b/docs/configuration.rst index 079c710..3746234 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -36,11 +36,12 @@ Configuration - REGISTER_SERIALIZER - serializer class in ``rest_auth.register.views.RegisterView``, default value ``rest_auth.register.serializers.RegisterSerializer`` +- **REST_AUTH_TOKEN_MODEL** - model class for tokens, default value ``rest_framework.authtoken.models`` +- **REST_AUTH_TOKEN_CREATOR** - callable to create tokens, default value ``rest_auth.utils.default_create_token``. - **REST_SESSION_LOGIN** - Enable session login in Login API view (default: True) - - **OLD_PASSWORD_FIELD_ENABLED** - set it to True if you want to have old password verification on password change enpoint (default: False) - **LOGOUT_ON_PASSWORD_CHANGE** - set to False if you want to keep the current user logged in after a password change diff --git a/rest_auth/app_settings.py b/rest_auth/app_settings.py index e0340b7..b77d1d2 100644 --- a/rest_auth/app_settings.py +++ b/rest_auth/app_settings.py @@ -7,8 +7,10 @@ from rest_auth.serializers import ( PasswordResetSerializer as DefaultPasswordResetSerializer, PasswordResetConfirmSerializer as DefaultPasswordResetConfirmSerializer, PasswordChangeSerializer as DefaultPasswordChangeSerializer) -from .utils import import_callable +from .utils import import_callable, default_create_token +create_token = import_callable( + getattr(settings, 'REST_AUTH_TOKEN_CREATOR', default_create_token)) serializers = getattr(settings, 'REST_AUTH_SERIALIZERS', {}) diff --git a/rest_auth/models.py b/rest_auth/models.py index e703865..a132f9c 100644 --- a/rest_auth/models.py +++ b/rest_auth/models.py @@ -1,3 +1,10 @@ -# from django.db import models +from django.conf import settings + +from rest_framework.authtoken.models import Token as DefaultTokenModel + +from .utils import import_callable # Register your models here. + +TokenModel = import_callable( + getattr(settings, 'REST_AUTH_TOKEN_MODEL', DefaultTokenModel)) diff --git a/rest_auth/registration/views.py b/rest_auth/registration/views.py index 81fd951..af6d7b6 100644 --- a/rest_auth/registration/views.py +++ b/rest_auth/registration/views.py @@ -3,23 +3,25 @@ from rest_framework.response import Response from rest_framework.permissions import AllowAny from rest_framework.generics import CreateAPIView from rest_framework import status -from rest_framework.authtoken.models import Token from rest_framework.exceptions import MethodNotAllowed from allauth.account.views import ConfirmEmailView from allauth.account.utils import complete_signup from allauth.account import app_settings as allauth_settings -from rest_auth.app_settings import TokenSerializer +from rest_auth.app_settings import (TokenSerializer, + create_token) from rest_auth.registration.serializers import (SocialLoginSerializer, VerifyEmailSerializer) -from .app_settings import RegisterSerializer from rest_auth.views import LoginView +from rest_auth.models import TokenModel +from .app_settings import RegisterSerializer class RegisterView(CreateAPIView): serializer_class = RegisterSerializer permission_classes = (AllowAny, ) + token_model = TokenModel def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) @@ -32,7 +34,7 @@ class RegisterView(CreateAPIView): def perform_create(self, serializer): user = serializer.save(self.request) - Token.objects.get_or_create(user=user) + create_token(self.token_model, user, serializer) complete_signup(self.request._request, user, allauth_settings.EMAIL_VERIFICATION, None) diff --git a/rest_auth/serializers.py b/rest_auth/serializers.py index 0fd6eab..735c188 100644 --- a/rest_auth/serializers.py +++ b/rest_auth/serializers.py @@ -6,8 +6,9 @@ from django.utils.http import urlsafe_base64_decode as uid_decoder from django.utils.translation import ugettext_lazy as _ from django.utils.encoding import force_text +from .models import TokenModel + from rest_framework import serializers, exceptions -from rest_framework.authtoken.models import Token from rest_framework.exceptions import ValidationError # Get the UserModel @@ -114,7 +115,7 @@ class TokenSerializer(serializers.ModelSerializer): """ class Meta: - model = Token + model = TokenModel fields = ('key',) diff --git a/rest_auth/utils.py b/rest_auth/utils.py index a32da60..e224f26 100644 --- a/rest_auth/utils.py +++ b/rest_auth/utils.py @@ -9,3 +9,8 @@ def import_callable(path_or_callable): assert isinstance(path_or_callable, string_types) package, attr = path_or_callable.rsplit('.', 1) return getattr(import_module(package), attr) + + +def default_create_token(token_model, user, serializer): + token, _ = token_model.objects.get_or_create(user=user) + return token diff --git a/rest_auth/views.py b/rest_auth/views.py index 3af1557..3bb6f6b 100644 --- a/rest_auth/views.py +++ b/rest_auth/views.py @@ -7,14 +7,14 @@ from rest_framework.views import APIView from rest_framework.response import Response from rest_framework.generics import GenericAPIView from rest_framework.permissions import IsAuthenticated, AllowAny -from rest_framework.authtoken.models import Token from rest_framework.generics import RetrieveUpdateAPIView from .app_settings import ( TokenSerializer, UserDetailsSerializer, LoginSerializer, PasswordResetSerializer, PasswordResetConfirmSerializer, - PasswordChangeSerializer + PasswordChangeSerializer, create_token ) +from .models import TokenModel class LoginView(GenericAPIView): @@ -30,13 +30,12 @@ class LoginView(GenericAPIView): """ permission_classes = (AllowAny,) serializer_class = LoginSerializer - token_model = Token + token_model = TokenModel response_serializer = TokenSerializer def login(self): self.user = self.serializer.validated_data['user'] - self.token, created = self.token_model.objects.get_or_create( - user=self.user) + self.token = create_token(self.token_model, self.user, self.serializer) if getattr(settings, 'REST_SESSION_LOGIN', True): login(self.request, self.user)