diff --git a/rest_auth/registration/views.py b/rest_auth/registration/views.py index af6d7b6..fa95e7d 100644 --- a/rest_auth/registration/views.py +++ b/rest_auth/registration/views.py @@ -23,14 +23,20 @@ class RegisterView(CreateAPIView): permission_classes = (AllowAny, ) token_model = TokenModel + def get_response_data(self, user): + if allauth_settings.EMAIL_VERIFICATION == \ + allauth_settings.EmailVerificationMethod.MANDATORY: + return {} + + return TokenSerializer(user.auth_token).data + def create(self, request, *args, **kwargs): serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) user = self.perform_create(serializer) headers = self.get_success_headers(serializer.data) - return Response(TokenSerializer(user.auth_token).data, - status=status.HTTP_201_CREATED, - headers=headers) + + return Response(self.get_response_data(user), status=status.HTTP_201_CREATED, headers=headers) def perform_create(self, serializer): user = serializer.save(self.request) diff --git a/rest_auth/tests/test_api.py b/rest_auth/tests/test_api.py index 222b3a7..0e0ced2 100644 --- a/rest_auth/tests/test_api.py +++ b/rest_auth/tests/test_api.py @@ -310,8 +310,10 @@ class APITestCase1(TestCase, BaseAPITestCase): # test empty payload self.post(self.register_url, data={}, status_code=400) - self.post(self.register_url, data=self.REGISTRATION_DATA, status_code=201) + result = self.post(self.register_url, data=self.REGISTRATION_DATA, status_code=201) + self.assertIn('key', result.data) self.assertEqual(get_user_model().objects.all().count(), user_count + 1) + new_user = get_user_model().objects.latest('id') self.assertEqual(new_user.username, self.REGISTRATION_DATA['username']) @@ -339,11 +341,12 @@ class APITestCase1(TestCase, BaseAPITestCase): status_code=status.HTTP_400_BAD_REQUEST ) - self.post( + result = self.post( self.register_url, data=self.REGISTRATION_DATA_WITH_EMAIL, status_code=status.HTTP_201_CREATED ) + self.assertNotIn('key', result.data) self.assertEqual(get_user_model().objects.all().count(), user_count + 1) self.assertEqual(len(mail.outbox), mail_count + 1) new_user = get_user_model().objects.latest('id')