mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-13 16:54:47 +03:00
Fix #9250: Prevent token overwrite and improve security
- Fix key collision issue that could overwrite existing tokens - Use force_insert=True only for new token instances - Replace os.urandom with secrets.token_hex for better security - Add comprehensive test suite to verify fix and backward compatibility - Ensure existing tokens can still be updated without breaking changes
This commit is contained in:
parent
de018df2aa
commit
d381901de4
|
@ -1,5 +1,4 @@
|
||||||
import binascii
|
import secrets
|
||||||
import os
|
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
@ -28,13 +27,32 @@ class Token(models.Model):
|
||||||
verbose_name_plural = _("Tokens")
|
verbose_name_plural = _("Tokens")
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Save the token instance.
|
||||||
|
|
||||||
|
If no key is provided, generates a cryptographically secure key.
|
||||||
|
For existing tokens with cleared keys, regenerates the key.
|
||||||
|
For new tokens, ensures they are inserted as new (not updated).
|
||||||
|
"""
|
||||||
if not self.key:
|
if not self.key:
|
||||||
self.key = self.generate_key()
|
self.key = self.generate_key()
|
||||||
|
# For new objects, force INSERT to prevent overwriting existing tokens
|
||||||
|
if self._state.adding:
|
||||||
|
kwargs['force_insert'] = True
|
||||||
return super().save(*args, **kwargs)
|
return super().save(*args, **kwargs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate_key(cls):
|
def generate_key(cls):
|
||||||
return binascii.hexlify(os.urandom(20)).decode()
|
"""
|
||||||
|
Generate a cryptographically secure token key.
|
||||||
|
|
||||||
|
Uses secrets.token_hex(20) which provides 40 hexadecimal characters
|
||||||
|
(160 bits of entropy) suitable for authentication tokens.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A 40-character hexadecimal string
|
||||||
|
"""
|
||||||
|
return secrets.token_hex(20)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.key
|
return self.key
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
import importlib
|
import importlib
|
||||||
|
import secrets
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from django.contrib.admin import site
|
from django.contrib.admin import site
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.core.management import CommandError, call_command
|
from django.core.management import CommandError, call_command
|
||||||
|
from django.db import IntegrityError
|
||||||
from django.test import TestCase, modify_settings
|
from django.test import TestCase, modify_settings
|
||||||
|
|
||||||
from rest_framework.authtoken.admin import TokenAdmin
|
from rest_framework.authtoken.admin import TokenAdmin
|
||||||
|
@ -19,8 +22,13 @@ class AuthTokenTests(TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.site = site
|
self.site = site
|
||||||
self.user = User.objects.create_user(username='test_user')
|
# CORRECTED: Only create the user. Each test will now create its own
|
||||||
self.token = Token.objects.create(key='test token', user=self.user)
|
# token(s) to ensure proper test isolation.
|
||||||
|
self.user = User.objects.create_user(
|
||||||
|
username='test_user',
|
||||||
|
email='test@example.com',
|
||||||
|
password='password'
|
||||||
|
)
|
||||||
|
|
||||||
def test_authtoken_can_be_imported_when_not_included_in_installed_apps(self):
|
def test_authtoken_can_be_imported_when_not_included_in_installed_apps(self):
|
||||||
import rest_framework.authtoken.models
|
import rest_framework.authtoken.models
|
||||||
|
@ -31,12 +39,16 @@ class AuthTokenTests(TestCase):
|
||||||
importlib.reload(rest_framework.authtoken.models)
|
importlib.reload(rest_framework.authtoken.models)
|
||||||
|
|
||||||
def test_model_admin_displayed_fields(self):
|
def test_model_admin_displayed_fields(self):
|
||||||
|
# Create a token specifically for this test.
|
||||||
|
token = Token.objects.create(user=self.user)
|
||||||
mock_request = object()
|
mock_request = object()
|
||||||
token_admin = TokenAdmin(self.token, self.site)
|
token_admin = TokenAdmin(token, self.site)
|
||||||
assert token_admin.get_fields(mock_request) == ('user',)
|
assert token_admin.get_fields(mock_request) == ('user',)
|
||||||
|
|
||||||
def test_token_string_representation(self):
|
def test_token_string_representation(self):
|
||||||
assert str(self.token) == 'test token'
|
# Create a token with a known key specifically for this test.
|
||||||
|
token = Token.objects.create(key='test token', user=self.user)
|
||||||
|
assert str(token) == 'test token'
|
||||||
|
|
||||||
def test_validate_raise_error_if_no_credentials_provided(self):
|
def test_validate_raise_error_if_no_credentials_provided(self):
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
|
@ -48,6 +60,68 @@ class AuthTokenTests(TestCase):
|
||||||
self.user.save()
|
self.user.save()
|
||||||
assert AuthTokenSerializer(data=data).is_valid()
|
assert AuthTokenSerializer(data=data).is_valid()
|
||||||
|
|
||||||
|
# --- Tests for Issue #9250 and secrets module refactor ---
|
||||||
|
|
||||||
|
def test_token_string_representation_is_randomly_generated_key(self):
|
||||||
|
"""
|
||||||
|
Ensure the string representation of a token is its key when auto-generated.
|
||||||
|
"""
|
||||||
|
token = Token.objects.create(user=self.user)
|
||||||
|
self.assertEqual(str(token), token.key)
|
||||||
|
|
||||||
|
def test_token_creation_collision_raises_integrity_error(self):
|
||||||
|
"""
|
||||||
|
Verify that creating a token with an existing key raises IntegrityError.
|
||||||
|
"""
|
||||||
|
user2 = User.objects.create_user('user2', 'user2@example.com', 'p')
|
||||||
|
existing_token = Token.objects.create(user=user2)
|
||||||
|
|
||||||
|
# Try to create another token with the same key
|
||||||
|
with self.assertRaises(IntegrityError):
|
||||||
|
Token.objects.create(key=existing_token.key, user=self.user)
|
||||||
|
|
||||||
|
def test_key_regeneration_on_save_is_not_a_breaking_change(self):
|
||||||
|
"""
|
||||||
|
Verify that when a token is created without a key, it generates one correctly.
|
||||||
|
This tests the backward compatibility scenario where existing code might
|
||||||
|
create tokens without explicitly setting a key.
|
||||||
|
"""
|
||||||
|
# Create a token without a key - it should generate one automatically
|
||||||
|
token = Token(user=self.user)
|
||||||
|
token.key = "" # Explicitly clear the key
|
||||||
|
token.save()
|
||||||
|
|
||||||
|
# Verify the key was generated
|
||||||
|
self.assertEqual(len(token.key), 40)
|
||||||
|
self.assertEqual(token.user, self.user)
|
||||||
|
|
||||||
|
# Verify it's saved in the database
|
||||||
|
token.refresh_from_db()
|
||||||
|
self.assertEqual(len(token.key), 40)
|
||||||
|
self.assertEqual(token.user, self.user)
|
||||||
|
|
||||||
|
def test_saving_existing_token_without_changes_does_not_alter_key(self):
|
||||||
|
"""
|
||||||
|
Ensure that calling save() on an existing token without modifications
|
||||||
|
does not change its key.
|
||||||
|
"""
|
||||||
|
token = Token.objects.create(user=self.user)
|
||||||
|
original_key = token.key
|
||||||
|
|
||||||
|
token.save()
|
||||||
|
self.assertEqual(token.key, original_key)
|
||||||
|
|
||||||
|
def test_generate_key_uses_secrets_module(self):
|
||||||
|
"""
|
||||||
|
Verify that `generate_key` correctly calls `secrets.token_hex`.
|
||||||
|
"""
|
||||||
|
with mock.patch('rest_framework.authtoken.models.secrets.token_hex') as mock_token_hex:
|
||||||
|
mock_token_hex.return_value = 'a_mocked_key_of_proper_length_0123456789'
|
||||||
|
key = Token.generate_key()
|
||||||
|
|
||||||
|
mock_token_hex.assert_called_once_with(20)
|
||||||
|
self.assertEqual(key, 'a_mocked_key_of_proper_length_0123456789')
|
||||||
|
|
||||||
|
|
||||||
class AuthTokenCommandTests(TestCase):
|
class AuthTokenCommandTests(TestCase):
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user