Fix #9250: Prevent token overwrite and improve security

- Fix key collision issue that could overwrite existing tokens
- Use force_insert=True only for new token instances
- Replace os.urandom with secrets.token_hex for better security
- Add comprehensive test suite to verify fix and backward compatibility
- Ensure existing tokens can still be updated without breaking changes
This commit is contained in:
Mahdi 2025-08-07 22:29:07 +03:30
parent de018df2aa
commit d381901de4
2 changed files with 99 additions and 7 deletions

View File

@ -1,5 +1,4 @@
import binascii import secrets
import os
from django.conf import settings from django.conf import settings
from django.db import models from django.db import models
@ -28,13 +27,32 @@ class Token(models.Model):
verbose_name_plural = _("Tokens") verbose_name_plural = _("Tokens")
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
"""
Save the token instance.
If no key is provided, generates a cryptographically secure key.
For existing tokens with cleared keys, regenerates the key.
For new tokens, ensures they are inserted as new (not updated).
"""
if not self.key: if not self.key:
self.key = self.generate_key() self.key = self.generate_key()
# For new objects, force INSERT to prevent overwriting existing tokens
if self._state.adding:
kwargs['force_insert'] = True
return super().save(*args, **kwargs) return super().save(*args, **kwargs)
@classmethod @classmethod
def generate_key(cls): def generate_key(cls):
return binascii.hexlify(os.urandom(20)).decode() """
Generate a cryptographically secure token key.
Uses secrets.token_hex(20) which provides 40 hexadecimal characters
(160 bits of entropy) suitable for authentication tokens.
Returns:
str: A 40-character hexadecimal string
"""
return secrets.token_hex(20)
def __str__(self): def __str__(self):
return self.key return self.key

View File

@ -1,10 +1,13 @@
import importlib import importlib
import secrets
from io import StringIO from io import StringIO
from unittest import mock
import pytest import pytest
from django.contrib.admin import site from django.contrib.admin import site
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.management import CommandError, call_command from django.core.management import CommandError, call_command
from django.db import IntegrityError
from django.test import TestCase, modify_settings from django.test import TestCase, modify_settings
from rest_framework.authtoken.admin import TokenAdmin from rest_framework.authtoken.admin import TokenAdmin
@ -19,8 +22,13 @@ class AuthTokenTests(TestCase):
def setUp(self): def setUp(self):
self.site = site self.site = site
self.user = User.objects.create_user(username='test_user') # CORRECTED: Only create the user. Each test will now create its own
self.token = Token.objects.create(key='test token', user=self.user) # token(s) to ensure proper test isolation.
self.user = User.objects.create_user(
username='test_user',
email='test@example.com',
password='password'
)
def test_authtoken_can_be_imported_when_not_included_in_installed_apps(self): def test_authtoken_can_be_imported_when_not_included_in_installed_apps(self):
import rest_framework.authtoken.models import rest_framework.authtoken.models
@ -31,12 +39,16 @@ class AuthTokenTests(TestCase):
importlib.reload(rest_framework.authtoken.models) importlib.reload(rest_framework.authtoken.models)
def test_model_admin_displayed_fields(self): def test_model_admin_displayed_fields(self):
# Create a token specifically for this test.
token = Token.objects.create(user=self.user)
mock_request = object() mock_request = object()
token_admin = TokenAdmin(self.token, self.site) token_admin = TokenAdmin(token, self.site)
assert token_admin.get_fields(mock_request) == ('user',) assert token_admin.get_fields(mock_request) == ('user',)
def test_token_string_representation(self): def test_token_string_representation(self):
assert str(self.token) == 'test token' # Create a token with a known key specifically for this test.
token = Token.objects.create(key='test token', user=self.user)
assert str(token) == 'test token'
def test_validate_raise_error_if_no_credentials_provided(self): def test_validate_raise_error_if_no_credentials_provided(self):
with pytest.raises(ValidationError): with pytest.raises(ValidationError):
@ -48,6 +60,68 @@ class AuthTokenTests(TestCase):
self.user.save() self.user.save()
assert AuthTokenSerializer(data=data).is_valid() assert AuthTokenSerializer(data=data).is_valid()
# --- Tests for Issue #9250 and secrets module refactor ---
def test_token_string_representation_is_randomly_generated_key(self):
"""
Ensure the string representation of a token is its key when auto-generated.
"""
token = Token.objects.create(user=self.user)
self.assertEqual(str(token), token.key)
def test_token_creation_collision_raises_integrity_error(self):
"""
Verify that creating a token with an existing key raises IntegrityError.
"""
user2 = User.objects.create_user('user2', 'user2@example.com', 'p')
existing_token = Token.objects.create(user=user2)
# Try to create another token with the same key
with self.assertRaises(IntegrityError):
Token.objects.create(key=existing_token.key, user=self.user)
def test_key_regeneration_on_save_is_not_a_breaking_change(self):
"""
Verify that when a token is created without a key, it generates one correctly.
This tests the backward compatibility scenario where existing code might
create tokens without explicitly setting a key.
"""
# Create a token without a key - it should generate one automatically
token = Token(user=self.user)
token.key = "" # Explicitly clear the key
token.save()
# Verify the key was generated
self.assertEqual(len(token.key), 40)
self.assertEqual(token.user, self.user)
# Verify it's saved in the database
token.refresh_from_db()
self.assertEqual(len(token.key), 40)
self.assertEqual(token.user, self.user)
def test_saving_existing_token_without_changes_does_not_alter_key(self):
"""
Ensure that calling save() on an existing token without modifications
does not change its key.
"""
token = Token.objects.create(user=self.user)
original_key = token.key
token.save()
self.assertEqual(token.key, original_key)
def test_generate_key_uses_secrets_module(self):
"""
Verify that `generate_key` correctly calls `secrets.token_hex`.
"""
with mock.patch('rest_framework.authtoken.models.secrets.token_hex') as mock_token_hex:
mock_token_hex.return_value = 'a_mocked_key_of_proper_length_0123456789'
key = Token.generate_key()
mock_token_hex.assert_called_once_with(20)
self.assertEqual(key, 'a_mocked_key_of_proper_length_0123456789')
class AuthTokenCommandTests(TestCase): class AuthTokenCommandTests(TestCase):