diff --git a/rest_auth/app_settings.py b/rest_auth/app_settings.py index 349864d..d868175 100644 --- a/rest_auth/app_settings.py +++ b/rest_auth/app_settings.py @@ -37,5 +37,3 @@ PasswordChangeSerializer = import_callable( serializers.get('PASSWORD_CHANGE_SERIALIZER', DefaultPasswordChangeSerializer) ) - - diff --git a/rest_auth/serializers.py b/rest_auth/serializers.py index 0fdc5f4..24dedef 100644 --- a/rest_auth/serializers.py +++ b/rest_auth/serializers.py @@ -121,14 +121,31 @@ class PasswordResetConfirmSerializer(serializers.Serializer): class PasswordChangeSerializer(serializers.Serializer): + old_password = serializers.CharField(max_length=128) new_password1 = serializers.CharField(max_length=128) new_password2 = serializers.CharField(max_length=128) set_password_form_class = SetPasswordForm + def __init__(self, *args, **kwargs): + self.old_password_field_enabled = getattr(settings, + 'OLD_PASSWORD_FIELD_ENABLED', False) + super(PasswordChangeSerializer, self).__init__(*args, **kwargs) + + if not self.old_password_field_enabled: + self.fields.pop('old_password') + + self.request = self.context.get('request') + self.user = self.request.user + + def validate_old_password(self, attrs, source): + if self.old_password_field_enabled and \ + not self.user.check_password(attrs.get(source, '')): + raise serializers.ValidationError('Invalid password') + return attrs + def validate(self, attrs): - request = self.context.get('request') - self.set_password_form = self.set_password_form_class(user=request.user, + self.set_password_form = self.set_password_form_class(user=self.user, data=attrs) if not self.set_password_form.is_valid(): diff --git a/rest_auth/tests.py b/rest_auth/tests.py index 4f66124..1134940 100644 --- a/rest_auth/tests.py +++ b/rest_auth/tests.py @@ -248,6 +248,40 @@ class APITestCase1(TestCase, BaseAPITestCase): # send empty payload self.post(self.password_change_url, data={}, status_code=400) + @override_settings(OLD_PASSWORD_FIELD_ENABLED=True) + def test_password_change_with_old_password(self): + login_payload = { + "username": self.USERNAME, + "password": self.PASS + } + User.objects.create_user(self.USERNAME, '', self.PASS) + self.post(self.login_url, data=login_payload, status_code=200) + self.token = self.response.json['key'] + + new_password_payload = { + "old_password": "%s!" % self.PASS, # wrong password + "new_password1": "new_person", + "new_password2": "new_person" + } + self.post(self.password_change_url, data=new_password_payload, + status_code=400) + + new_password_payload = { + "old_password": self.PASS, + "new_password1": "new_person", + "new_password2": "new_person" + } + self.post(self.password_change_url, data=new_password_payload, + status_code=200) + + # user should not be able to login using old password + self.post(self.login_url, data=login_payload, status_code=400) + + # new password should work + login_payload['password'] = new_password_payload['new_password1'] + self.post(self.login_url, data=login_payload, status_code=200) + + def test_password_reset(self): user = User.objects.create_user(self.USERNAME, self.EMAIL, self.PASS)