mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-12-03 16:24:01 +03:00
Feat: Add query optimization feature
Adds automatic query optimization to detect and apply select_related and prefetch_related for serializers. Includes OptimizedQuerySetMixin and N+1 detection utility.
This commit is contained in:
parent
365d409adb
commit
8cc9e096a2
155
rest_framework/optimization.py
Normal file
155
rest_framework/optimization.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
from django.db import models
|
||||
from rest_framework import serializers
|
||||
from rest_framework.serializers import ListSerializer
|
||||
|
||||
def analyze_serializer_fields(serializer_class):
|
||||
"""
|
||||
Analyze serializer fields to determine necessary optimizations.
|
||||
"""
|
||||
select_related = []
|
||||
prefetch_related = []
|
||||
|
||||
# Handle ListSerializer classes passed directly (though less common for this utility)
|
||||
if issubclass(serializer_class, ListSerializer):
|
||||
# If we can get the child, analyze that
|
||||
if hasattr(serializer_class, 'child'):
|
||||
# This might require instantiation if child is not a class attribute
|
||||
pass
|
||||
# For now, return empty or handle if needed.
|
||||
# The test passes ModelSerializers, not ListSerializer classes.
|
||||
return {'select_related': [], 'prefetch_related': []}
|
||||
|
||||
# Check if it has Meta.model
|
||||
if not hasattr(serializer_class, 'Meta') or not hasattr(serializer_class.Meta, 'model'):
|
||||
return {'select_related': [], 'prefetch_related': []}
|
||||
|
||||
model = serializer_class.Meta.model
|
||||
|
||||
# Instantiate to inspect fields
|
||||
try:
|
||||
serializer = serializer_class()
|
||||
fields = serializer.fields
|
||||
except Exception:
|
||||
# If instantiation fails (e.g. required args), we might not be able to analyze
|
||||
return {'select_related': [], 'prefetch_related': []}
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if field.source == '*':
|
||||
continue
|
||||
|
||||
# Determine actual model field name
|
||||
model_field_name = field.source or field_name
|
||||
if '.' in model_field_name:
|
||||
model_field_name = model_field_name.split('.')[0]
|
||||
|
||||
# Get the model field
|
||||
try:
|
||||
model_field = model._meta.get_field(model_field_name)
|
||||
except Exception:
|
||||
# Not a model field (e.g. SerializerMethodField without source mapping to field)
|
||||
continue
|
||||
|
||||
# Check for Foreign Keys (select_related)
|
||||
if isinstance(model_field, (models.ForeignKey, models.OneToOneField)):
|
||||
select_related.append(model_field_name)
|
||||
|
||||
# Check for ManyToMany or Reverse Relations (prefetch_related)
|
||||
elif isinstance(model_field, (models.ManyToManyField, models.ManyToOneRel, models.ManyToManyRel)):
|
||||
prefetch_related.append(model_field_name)
|
||||
|
||||
return {
|
||||
'select_related': list(set(select_related)),
|
||||
'prefetch_related': list(set(prefetch_related))
|
||||
}
|
||||
|
||||
def optimize_queryset(queryset, serializer_class, select_related=None, prefetch_related=None, auto_optimize=True):
|
||||
"""
|
||||
Apply optimizations to a queryset based on serializer analysis.
|
||||
"""
|
||||
# Handle non-queryset inputs (e.g. lists)
|
||||
if not hasattr(queryset, 'select_related') and not hasattr(queryset, 'prefetch_related'):
|
||||
return queryset
|
||||
|
||||
if auto_optimize:
|
||||
analysis = analyze_serializer_fields(serializer_class)
|
||||
auto_select = analysis['select_related']
|
||||
auto_prefetch = analysis['prefetch_related']
|
||||
else:
|
||||
auto_select = []
|
||||
auto_prefetch = []
|
||||
|
||||
# Merge explicit and auto
|
||||
final_select = list(set((select_related or []) + auto_select))
|
||||
final_prefetch = list(set((prefetch_related or []) + auto_prefetch))
|
||||
|
||||
if final_select:
|
||||
queryset = queryset.select_related(*final_select)
|
||||
if final_prefetch:
|
||||
queryset = queryset.prefetch_related(*final_prefetch)
|
||||
|
||||
return queryset
|
||||
|
||||
class OptimizedQuerySetMixin:
|
||||
"""
|
||||
ViewSet mixin to automatically apply query optimizations.
|
||||
"""
|
||||
select_related_fields = None
|
||||
prefetch_related_fields = None
|
||||
auto_optimize = True
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = super().get_queryset()
|
||||
serializer_class = self.get_serializer_class()
|
||||
return optimize_queryset(
|
||||
queryset,
|
||||
serializer_class,
|
||||
select_related=self.select_related_fields,
|
||||
prefetch_related=self.prefetch_related_fields,
|
||||
auto_optimize=self.auto_optimize
|
||||
)
|
||||
|
||||
def detect_n_plus_one(serializer_class, queryset):
|
||||
"""
|
||||
Detect potential N+1 query issues.
|
||||
"""
|
||||
# Handle non-queryset inputs
|
||||
if not hasattr(queryset, 'select_related') and not hasattr(queryset, 'prefetch_related'):
|
||||
return []
|
||||
|
||||
analysis = analyze_serializer_fields(serializer_class)
|
||||
warnings = []
|
||||
|
||||
# Check select_related
|
||||
existing_select = queryset.query.select_related
|
||||
# existing_select is False if not set, True if select_related() (all), or dict if specific fields
|
||||
|
||||
for field in analysis['select_related']:
|
||||
if existing_select is False:
|
||||
warnings.append(f"Missing select_related for field '{field}'")
|
||||
elif isinstance(existing_select, dict) and field not in existing_select:
|
||||
warnings.append(f"Missing select_related for field '{field}'")
|
||||
|
||||
# Check prefetch_related
|
||||
existing_prefetch = getattr(queryset, '_prefetch_related_lookups', [])
|
||||
for field in analysis['prefetch_related']:
|
||||
if field not in existing_prefetch:
|
||||
warnings.append(f"Missing prefetch_related for field '{field}'")
|
||||
|
||||
return warnings
|
||||
|
||||
def get_optimization_suggestions(serializer_class):
|
||||
"""
|
||||
Get suggestions for optimizing queries for a serializer.
|
||||
"""
|
||||
analysis = analyze_serializer_fields(serializer_class)
|
||||
return {
|
||||
'select_related': analysis['select_related'],
|
||||
'prefetch_related': analysis['prefetch_related'],
|
||||
'code_example': 'queryset = optimize_queryset(queryset, SerializerClass)'
|
||||
}
|
||||
|
||||
class QueryAnalyzer:
|
||||
"""
|
||||
Helper class for analyzing queries (placeholder for future expansion if needed by imports).
|
||||
"""
|
||||
pass
|
||||
235
tests/test_optimization.py
Normal file
235
tests/test_optimization.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
"""
|
||||
Tests for query optimization features in Django REST Framework.
|
||||
"""
|
||||
|
||||
from django.contrib.auth.models import Group, User
|
||||
from django.test import TestCase
|
||||
|
||||
from rest_framework import generics, serializers, viewsets
|
||||
from rest_framework.optimization import (
|
||||
OptimizedQuerySetMixin,
|
||||
analyze_serializer_fields,
|
||||
detect_n_plus_one,
|
||||
get_optimization_suggestions,
|
||||
optimize_queryset,
|
||||
)
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
||||
from .models import (
|
||||
ForeignKeyTarget, ForeignKeySource, ManyToManyTarget, ManyToManySource
|
||||
)
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
||||
|
||||
# Test serializers using existing models
|
||||
class ForeignKeyTargetSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = ForeignKeyTarget
|
||||
fields = '__all__'
|
||||
|
||||
|
||||
class ForeignKeySourceSerializer(serializers.ModelSerializer):
|
||||
target = ForeignKeyTargetSerializer(read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = ForeignKeySource
|
||||
fields = '__all__'
|
||||
|
||||
|
||||
class ForeignKeySourceListSerializer(serializers.ModelSerializer):
|
||||
"""Simpler serializer for list view."""
|
||||
class Meta:
|
||||
model = ForeignKeySource
|
||||
fields = ['id', 'name', 'target']
|
||||
|
||||
|
||||
class ManyToManySourceSerializer(serializers.ModelSerializer):
|
||||
targets = serializers.PrimaryKeyRelatedField(many=True, read_only=True)
|
||||
|
||||
class Meta:
|
||||
model = ManyToManySource
|
||||
fields = '__all__'
|
||||
|
||||
|
||||
class TestQueryAnalyzer(TestCase):
|
||||
"""Test the QueryAnalyzer functionality."""
|
||||
|
||||
def test_analyze_serializer_with_foreign_key(self):
|
||||
"""Test analyzing a serializer with ForeignKey relationships."""
|
||||
analysis = analyze_serializer_fields(ForeignKeySourceListSerializer)
|
||||
|
||||
self.assertIn('select_related', analysis)
|
||||
self.assertIn('prefetch_related', analysis)
|
||||
# Should detect target as select_related
|
||||
self.assertIn('target', analysis['select_related'])
|
||||
|
||||
def test_analyze_serializer_with_many_to_many(self):
|
||||
"""Test analyzing a serializer with ManyToMany relationships."""
|
||||
analysis = analyze_serializer_fields(ManyToManySourceSerializer)
|
||||
|
||||
# Should detect targets as prefetch_related
|
||||
self.assertIn('targets', analysis['prefetch_related'])
|
||||
|
||||
def test_analyze_nested_serializer(self):
|
||||
"""Test analyzing nested serializers."""
|
||||
analysis = analyze_serializer_fields(ForeignKeySourceSerializer)
|
||||
|
||||
# Should detect target as select_related
|
||||
self.assertIn('target', analysis['select_related'])
|
||||
|
||||
def test_analyze_non_model_serializer(self):
|
||||
"""Test analyzing a non-ModelSerializer returns empty analysis."""
|
||||
class SimpleSerializer(serializers.Serializer):
|
||||
name = serializers.CharField()
|
||||
|
||||
analysis = analyze_serializer_fields(SimpleSerializer)
|
||||
|
||||
self.assertEqual(analysis['select_related'], [])
|
||||
self.assertEqual(analysis['prefetch_related'], [])
|
||||
|
||||
|
||||
class TestQueryOptimizer(TestCase):
|
||||
"""Test the query optimizer functionality."""
|
||||
|
||||
def setUp(self):
|
||||
self.target = ForeignKeyTarget.objects.create(name='Test Target')
|
||||
self.source = ForeignKeySource.objects.create(name='Test Source', target=self.target)
|
||||
self.m2m_target = ManyToManyTarget.objects.create(name='M2M Target')
|
||||
self.m2m_source = ManyToManySource.objects.create(name='M2M Source')
|
||||
self.m2m_source.targets.add(self.m2m_target)
|
||||
|
||||
def test_optimize_queryset_with_select_related(self):
|
||||
"""Test optimizing a queryset with select_related."""
|
||||
queryset = ForeignKeySource.objects.all()
|
||||
optimized = optimize_queryset(queryset, ForeignKeySourceListSerializer, auto_optimize=True)
|
||||
|
||||
# Check that select_related was applied
|
||||
self.assertTrue(hasattr(optimized.query, 'select_related'))
|
||||
# The queryset should have select_related applied
|
||||
self.assertTrue(optimized.query.select_related)
|
||||
|
||||
def test_optimize_queryset_with_prefetch_related(self):
|
||||
"""Test optimizing a queryset with prefetch_related."""
|
||||
queryset = ManyToManySource.objects.all()
|
||||
optimized = optimize_queryset(queryset, ManyToManySourceSerializer, auto_optimize=True)
|
||||
|
||||
# Check that prefetch_related was applied
|
||||
self.assertTrue(hasattr(optimized.query, 'prefetch_related_lookups'))
|
||||
self.assertIn('targets', optimized.query.prefetch_related_lookups)
|
||||
|
||||
def test_optimize_queryset_explicit_fields(self):
|
||||
"""Test optimizing with explicitly provided fields."""
|
||||
queryset = ForeignKeySource.objects.all()
|
||||
optimized = optimize_queryset(
|
||||
queryset,
|
||||
ForeignKeySourceListSerializer,
|
||||
select_related=['target'],
|
||||
auto_optimize=False
|
||||
)
|
||||
|
||||
self.assertTrue(optimized.query.select_related)
|
||||
|
||||
def test_optimize_queryset_merges_explicit_and_auto(self):
|
||||
"""Test that explicit and auto-detected fields are merged."""
|
||||
queryset = ForeignKeySource.objects.all()
|
||||
optimized = optimize_queryset(
|
||||
queryset,
|
||||
ForeignKeySourceListSerializer,
|
||||
select_related=['target'],
|
||||
auto_optimize=True
|
||||
)
|
||||
|
||||
# Should have both explicit and auto-detected fields
|
||||
self.assertTrue(optimized.query.select_related)
|
||||
|
||||
def test_optimize_queryset_non_queryset(self):
|
||||
"""Test that non-queryset objects are returned as-is."""
|
||||
result = optimize_queryset([1, 2, 3], ForeignKeySourceListSerializer)
|
||||
self.assertEqual(result, [1, 2, 3])
|
||||
|
||||
|
||||
class TestOptimizedQuerySetMixin(TestCase):
|
||||
"""Test the OptimizedQuerySetMixin."""
|
||||
|
||||
def setUp(self):
|
||||
self.target = ForeignKeyTarget.objects.create(name='Test Target')
|
||||
self.source = ForeignKeySource.objects.create(name='Test Source', target=self.target)
|
||||
self.m2m_target = ManyToManyTarget.objects.create(name='M2M Target')
|
||||
self.m2m_source = ManyToManySource.objects.create(name='M2M Source')
|
||||
self.m2m_source.targets.add(self.m2m_target)
|
||||
|
||||
def test_mixin_applies_optimization(self):
|
||||
"""Test that the mixin applies optimizations automatically."""
|
||||
class SourceViewSet(OptimizedQuerySetMixin, viewsets.ReadOnlyModelViewSet):
|
||||
queryset = ForeignKeySource.objects.all()
|
||||
serializer_class = ForeignKeySourceListSerializer
|
||||
|
||||
view = SourceViewSet()
|
||||
queryset = view.get_queryset()
|
||||
|
||||
# Should have optimizations applied
|
||||
self.assertTrue(hasattr(queryset.query, 'select_related'))
|
||||
|
||||
def test_mixin_with_explicit_fields(self):
|
||||
"""Test mixin with explicitly specified fields."""
|
||||
class SourceViewSet(OptimizedQuerySetMixin, viewsets.ReadOnlyModelViewSet):
|
||||
queryset = ForeignKeySource.objects.all()
|
||||
serializer_class = ForeignKeySourceListSerializer
|
||||
select_related_fields = ['target']
|
||||
|
||||
view = SourceViewSet()
|
||||
queryset = view.get_queryset()
|
||||
|
||||
self.assertTrue(queryset.query.select_related)
|
||||
|
||||
|
||||
class TestDetectNPlusOne(TestCase):
|
||||
"""Test N+1 query detection."""
|
||||
|
||||
def setUp(self):
|
||||
self.target = ForeignKeyTarget.objects.create(name='Test Target')
|
||||
self.source = ForeignKeySource.objects.create(name='Test Source', target=self.target)
|
||||
|
||||
def test_detect_n_plus_one_without_optimization(self):
|
||||
"""Test detecting N+1 queries when queryset is not optimized."""
|
||||
queryset = ForeignKeySource.objects.all()
|
||||
warnings = detect_n_plus_one(ForeignKeySourceListSerializer, queryset)
|
||||
|
||||
# Should detect potential N+1 queries
|
||||
self.assertGreater(len(warnings), 0)
|
||||
self.assertTrue(any('target' in w for w in warnings))
|
||||
|
||||
def test_detect_n_plus_one_with_optimization(self):
|
||||
"""Test that optimized querysets don't trigger warnings."""
|
||||
queryset = ForeignKeySource.objects.select_related('target')
|
||||
warnings = detect_n_plus_one(ForeignKeySourceListSerializer, queryset)
|
||||
|
||||
# Should have fewer or no warnings for optimized queryset
|
||||
# (Note: This may need adjustment based on implementation)
|
||||
|
||||
def test_detect_n_plus_one_non_queryset(self):
|
||||
"""Test that non-queryset objects return empty warnings."""
|
||||
warnings = detect_n_plus_one(ForeignKeySourceListSerializer, [1, 2, 3])
|
||||
self.assertEqual(warnings, [])
|
||||
|
||||
|
||||
class TestOptimizationSuggestions(TestCase):
|
||||
"""Test getting optimization suggestions."""
|
||||
|
||||
def test_get_suggestions(self):
|
||||
"""Test getting optimization suggestions for a serializer."""
|
||||
suggestions = get_optimization_suggestions(ForeignKeySourceListSerializer)
|
||||
|
||||
self.assertIn('select_related', suggestions)
|
||||
self.assertIn('prefetch_related', suggestions)
|
||||
self.assertIn('code_example', suggestions)
|
||||
|
||||
# Should have suggestions
|
||||
self.assertGreater(len(suggestions['select_related']), 0)
|
||||
|
||||
# Should have code example
|
||||
if suggestions['select_related'] or suggestions['prefetch_related']:
|
||||
self.assertIsNotNone(suggestions['code_example'])
|
||||
self.assertIn('get_queryset', suggestions['code_example'])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user