diff --git a/rest_auth/registration/views.py b/rest_auth/registration/views.py index d6638b6..fd80e82 100644 --- a/rest_auth/registration/views.py +++ b/rest_auth/registration/views.py @@ -14,7 +14,8 @@ from allauth.account.views import ConfirmEmailView from allauth.account.utils import complete_signup from allauth.account import app_settings as allauth_settings -from rest_auth.app_settings import (TokenSerializer, +from rest_auth.app_settings import (UserDetailsSerializer, + TokenSerializer, JWTSerializer, create_token) from rest_auth.models import TokenModel @@ -49,8 +50,9 @@ class RegisterView(CreateAPIView): 'token': self.token } return JWTSerializer(data).data - else: + elif getattr(settings, 'REST_USE_TOKEN', True): return TokenSerializer(user.auth_token).data + return UserDetailsSerializer(user).data def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) @@ -66,12 +68,13 @@ class RegisterView(CreateAPIView): user = serializer.save(self.request) if getattr(settings, 'REST_USE_JWT', False): self.token = jwt_encode(user) - else: + elif getattr(settings, 'REST_USE_TOKEN', True): create_token(self.token_model, user, serializer) complete_signup(self.request._request, user, allauth_settings.EMAIL_VERIFICATION, None) + return user diff --git a/rest_auth/tests/test_api.py b/rest_auth/tests/test_api.py index 1b8fad0..c2cc4f3 100644 --- a/rest_auth/tests/test_api.py +++ b/rest_auth/tests/test_api.py @@ -426,6 +426,18 @@ class APITestCase1(TestCase, BaseAPITestCase): self._login() self._logout() + @override_settings(REST_USE_TOKEN=False) + def test_registration_without_token(self): + user_count = get_user_model().objects.all().count() + + self.post(self.register_url, data=self.REGISTRATION_DATA_WITH_EMAIL, status_code=201) + self.assertEqual(self.response.json['username'], self.USERNAME) + self.assertEqual(self.response.json['email'], self.EMAIL) + + self.assertEqual(get_user_model().objects.all().count(), user_count + 1) + self._login() + self._logout() + def test_registration_with_invalid_password(self): data = self.REGISTRATION_DATA.copy() data['password2'] = 'foobar'