Add query optimization module and settings

- Add rest_framework/optimization module with query analyzer, optimizer, mixins, and middleware
- Add ENABLE_QUERY_OPTIMIZATION and WARN_ON_N_PLUS_ONE settings
- Add comprehensive test suite in tests/test_optimization.py

This feature provides automatic query optimization to prevent N+1 query
problems by analyzing serializer fields and applying select_related()
and prefetch_related() optimizations automatically.
This commit is contained in:
malikabdullahnazar 2025-11-25 22:29:15 +05:00
parent 8cc9e096a2
commit 3ff4f68883
6 changed files with 679 additions and 0 deletions

View File

@ -0,0 +1,27 @@
"""
Query optimization utilities for Django REST Framework.
This module provides tools to automatically detect and prevent N+1 query problems
in DRF serializers by analyzing serializer fields and optimizing querysets.
"""
from rest_framework.optimization.mixins import OptimizedQuerySetMixin
from rest_framework.optimization.optimizer import (
optimize_queryset,
analyze_serializer_fields,
get_optimization_suggestions,
)
from rest_framework.optimization.query_analyzer import (
QueryAnalyzer,
detect_n_plus_one,
)
__all__ = [
'OptimizedQuerySetMixin',
'optimize_queryset',
'analyze_serializer_fields',
'get_optimization_suggestions',
'QueryAnalyzer',
'detect_n_plus_one',
]

View File

@ -0,0 +1,118 @@
"""
Middleware for detecting N+1 queries in development mode.
This middleware can be added to Django's MIDDLEWARE setting to automatically
detect and warn about N+1 query problems during development.
"""
import warnings
from django.conf import settings
from django.db import connection
from django.utils.deprecation import MiddlewareMixin
class QueryOptimizationMiddleware(MiddlewareMixin):
"""
Middleware that detects potential N+1 queries in development mode.
This middleware tracks database queries and warns when patterns that
suggest N+1 queries are detected.
Usage:
Add to MIDDLEWARE in settings.py:
MIDDLEWARE = [
...
'rest_framework.optimization.middleware.QueryOptimizationMiddleware',
]
Settings:
- QUERY_OPTIMIZATION_WARN_THRESHOLD: Number of similar queries to trigger warning (default: 5)
"""
def __init__(self, get_response):
self.get_response = get_response
self.warn_threshold = getattr(
settings,
'QUERY_OPTIMIZATION_WARN_THRESHOLD',
5
)
super().__init__(get_response)
def process_request(self, request):
"""Reset query tracking for each request."""
if settings.DEBUG:
connection.queries_log.clear()
return None
def process_response(self, request, response):
"""Analyze queries and warn about potential N+1 issues."""
if not settings.DEBUG:
return response
try:
# In Django 5.2+, use queries_log, fallback to queries for older versions
if hasattr(connection, 'queries_log'):
queries = connection.queries_log
else:
queries = getattr(connection, 'queries', [])
if len(queries) > self.warn_threshold:
# Analyze queries for patterns
self._analyze_queries(queries, request)
except Exception as e:
# Don't break the request if analysis fails
import traceback
if settings.DEBUG:
# Only log in DEBUG mode to avoid noise
warnings.warn(f"Query optimization middleware error: {e}", UserWarning)
return response
def _analyze_queries(self, queries, request):
"""Analyze queries for N+1 patterns."""
# Group queries by SQL pattern
query_patterns = {}
for query in queries:
# Handle both dict format (old Django) and string format (new Django)
if isinstance(query, dict):
sql = query.get('sql', '')
elif isinstance(query, str):
sql = query
else:
# Django 5.2+ might use a different format
sql = str(query)
if not sql:
continue
# Normalize SQL (remove values, keep structure)
normalized = self._normalize_sql(sql)
if normalized not in query_patterns:
query_patterns[normalized] = []
query_patterns[normalized].append(query)
# Warn about patterns that appear many times (potential N+1)
for pattern, query_list in query_patterns.items():
if len(query_list) >= self.warn_threshold:
# Check if it's a SELECT query (not INSERT/UPDATE/DELETE)
if 'SELECT' in pattern.upper():
warnings.warn(
f"Potential N+1 query detected: {len(query_list)} similar queries "
f"executed for pattern: {pattern[:100]}... "
f"Consider using select_related() or prefetch_related().",
UserWarning
)
def _normalize_sql(self, sql):
"""Normalize SQL by removing values and keeping structure."""
import re
# Remove quoted strings
sql = re.sub(r"'[^']*'", "'?'", sql)
sql = re.sub(r'"[^"]*"', '"?"', sql)
# Remove numbers
sql = re.sub(r'\b\d+\b', '?', sql)
# Normalize whitespace
sql = ' '.join(sql.split())
return sql

View File

@ -0,0 +1,119 @@
"""
Mixins for automatic query optimization in Django REST Framework viewsets.
This module provides mixins that automatically optimize querysets based on
serializer field analysis.
"""
import warnings
from django.db.models import QuerySet
from rest_framework.settings import api_settings
from rest_framework.optimization.optimizer import optimize_queryset
from rest_framework.optimization.query_analyzer import detect_n_plus_one
class OptimizedQuerySetMixin:
"""
Mixin that automatically optimizes querysets based on serializer analysis.
This mixin can be added to any GenericAPIView or ViewSet to automatically
apply select_related and prefetch_related optimizations based on the
serializer's field definitions.
Usage:
class MyViewSet(OptimizedQuerySetMixin, ModelViewSet):
queryset = MyModel.objects.all()
serializer_class = MySerializer
You can also explicitly specify optimizations:
class MyViewSet(OptimizedQuerySetMixin, ModelViewSet):
queryset = MyModel.objects.all()
serializer_class = MySerializer
select_related_fields = ['author', 'category']
prefetch_related_fields = ['tags', 'comments']
Settings:
- ENABLE_QUERY_OPTIMIZATION: Enable/disable automatic optimization (default: True)
- WARN_ON_N_PLUS_ONE: Show warnings when N+1 queries are detected (default: True in DEBUG)
"""
# Explicit optimization fields (optional)
select_related_fields = None
prefetch_related_fields = None
# Control optimization behavior
enable_auto_optimization = True
warn_on_n_plus_one = None
def get_queryset(self):
"""
Get the queryset with automatic optimizations applied.
This method extends the base get_queryset() to automatically apply
select_related and prefetch_related based on serializer analysis.
"""
queryset = super().get_queryset()
# Check if optimization is enabled
enable_optimization = getattr(
api_settings,
'ENABLE_QUERY_OPTIMIZATION',
self.enable_auto_optimization
)
if not enable_optimization:
return queryset
# Get serializer class
serializer_class = self.get_serializer_class()
if not serializer_class:
return queryset
# Optimize queryset
try:
queryset = optimize_queryset(
queryset,
serializer_class,
select_related=self.select_related_fields,
prefetch_related=self.prefetch_related_fields,
auto_optimize=self.enable_auto_optimization
)
except Exception as e:
# If optimization fails, log warning but don't break
if self._should_warn():
warnings.warn(
f"Query optimization failed: {e}. "
f"Continuing with unoptimized queryset.",
UserWarning
)
return queryset
# Check for N+1 queries and warn if enabled
if self._should_warn():
warnings_list = detect_n_plus_one(serializer_class, queryset)
for warning_msg in warnings_list:
warnings.warn(warning_msg, UserWarning)
return queryset
def _should_warn(self):
"""Determine if warnings should be shown."""
if self.warn_on_n_plus_one is not None:
return self.warn_on_n_plus_one
# Default: warn in DEBUG mode
from django.conf import settings
warn_on_n_plus_one = getattr(
api_settings,
'WARN_ON_N_PLUS_ONE',
getattr(settings, 'DEBUG', False)
)
return warn_on_n_plus_one
# Backward compatibility alias
QueryOptimizationMixin = OptimizedQuerySetMixin

View File

@ -0,0 +1,163 @@
"""
Query optimizer for automatically optimizing querysets based on serializer fields.
This module provides utilities to automatically apply select_related and
prefetch_related optimizations to querysets based on serializer field analysis.
"""
from django.db.models import QuerySet
from rest_framework import serializers
from rest_framework.optimization.query_analyzer import QueryAnalyzer
def analyze_serializer_fields(serializer_class):
"""
Analyze a serializer class to identify required query optimizations.
Args:
serializer_class: The serializer class to analyze
Returns:
Dictionary with 'select_related' and 'prefetch_related' lists
"""
analyzer = QueryAnalyzer(serializer_class)
return analyzer.analyze()
def optimize_queryset(
queryset,
serializer_class,
select_related=None,
prefetch_related=None,
auto_optimize=True
):
"""
Optimize a queryset based on serializer analysis and/or explicit parameters.
Args:
queryset: The queryset to optimize
serializer_class: The serializer class that will be used
select_related: Explicit list of fields for select_related (optional)
prefetch_related: Explicit list of fields for prefetch_related (optional)
auto_optimize: If True, automatically analyze serializer and apply optimizations
Returns:
Optimized queryset
"""
if not isinstance(queryset, QuerySet):
return queryset
# Start with the original queryset
optimized = queryset
# Auto-optimize based on serializer analysis
if auto_optimize:
analysis = analyze_serializer_fields(serializer_class)
# Merge auto-detected with explicit parameters
if select_related is None:
select_related = analysis.get('select_related', [])
else:
# Merge lists, avoiding duplicates
auto_select = analysis.get('select_related', [])
select_related = list(set(select_related + auto_select))
if prefetch_related is None:
prefetch_related = analysis.get('prefetch_related', [])
else:
# Merge lists, avoiding duplicates
auto_prefetch = analysis.get('prefetch_related', [])
prefetch_related = list(set(prefetch_related + auto_prefetch))
# Apply select_related
if select_related:
# Check if queryset already has select_related
existing_select = getattr(optimized.query, 'select_related', {})
# Handle case where select_related is True (all fields selected)
if existing_select is True:
# All fields already selected, skip
new_select = []
elif isinstance(existing_select, dict):
# Only add fields that aren't already selected
new_select = [
field for field in select_related
if field not in existing_select and not any(
field.startswith(sel) for sel in existing_select.keys()
)
]
else:
# Empty or unknown format, add all
new_select = select_related
if new_select:
if len(new_select) == 1:
optimized = optimized.select_related(new_select[0])
else:
optimized = optimized.select_related(*new_select)
# Apply prefetch_related
if prefetch_related:
# Check if queryset already has prefetch_related
existing_prefetch = getattr(optimized.query, 'prefetch_related_lookups', set())
# Only add fields that aren't already prefetched
new_prefetch = [
field for field in prefetch_related
if field not in existing_prefetch and not any(
field.startswith(pref) for pref in existing_prefetch
)
]
if new_prefetch:
for field in new_prefetch:
optimized = optimized.prefetch_related(field)
return optimized
def get_optimization_suggestions(serializer_class):
"""
Get optimization suggestions for a serializer class.
Args:
serializer_class: The serializer class to analyze
Returns:
Dictionary with optimization suggestions and code examples
"""
analysis = analyze_serializer_fields(serializer_class)
suggestions = {
'select_related': analysis.get('select_related', []),
'prefetch_related': analysis.get('prefetch_related', []),
'nested_serializers': analysis.get('nested_serializers', []),
'code_example': None
}
# Generate code example
if suggestions['select_related'] or suggestions['prefetch_related']:
parts = []
if suggestions['select_related']:
if len(suggestions['select_related']) == 1:
parts.append(f".select_related('{suggestions['select_related'][0]}')")
else:
fields_str = "', '".join(suggestions['select_related'])
parts.append(f".select_related('{fields_str}')")
if suggestions['prefetch_related']:
for field in suggestions['prefetch_related']:
parts.append(f".prefetch_related('{field}')")
if parts:
suggestions['code_example'] = (
"def get_queryset(self):\n"
" queryset = super().get_queryset()\n"
f" return queryset{''.join(parts)}"
)
return suggestions

View File

@ -0,0 +1,248 @@
"""
Query analyzer for detecting N+1 query problems in Django REST Framework.
This module provides utilities to analyze serializer fields and detect potential
N+1 query issues before they occur.
"""
import warnings
from django.db import models
from django.db.models import ForeignKey, ManyToManyField, OneToOneField
from rest_framework import serializers
from rest_framework.relations import RelatedField, ManyRelatedField
from rest_framework.utils import model_meta
class QueryAnalyzer:
"""
Analyzes serializer fields to detect potential N+1 query problems.
This class examines serializer field definitions to identify relationships
that may cause N+1 queries when serializing querysets.
"""
def __init__(self, serializer_class):
"""
Initialize the analyzer with a serializer class.
Args:
serializer_class: The serializer class to analyze
"""
self.serializer_class = serializer_class
self._field_analysis = None
def analyze(self):
"""
Analyze the serializer and return a dictionary with optimization suggestions.
Returns:
Dictionary containing:
- select_related: List of fields that should use select_related
- prefetch_related: List of fields that should use prefetch_related
- nested_serializers: List of nested serializer fields
"""
if self._field_analysis is None:
self._field_analysis = self._analyze_fields()
return self._field_analysis
def _analyze_fields(self):
"""Analyze serializer fields to identify relationships."""
analysis = {
'select_related': [],
'prefetch_related': [],
'nested_serializers': [],
'warnings': []
}
if not issubclass(self.serializer_class, serializers.ModelSerializer):
return analysis
# Get the model from the serializer
model = getattr(self.serializer_class.Meta, 'model', None)
if not model:
return analysis
# Get field info using DRF's utility
try:
field_info = model_meta.get_field_info(model)
except Exception:
return analysis
# Analyze declared fields
serializer = self.serializer_class()
fields = serializer.fields
for field_name, field in fields.items():
# Analyze fields that are readable (not write_only)
# This includes read_only fields and fields that can be both read and written
if not field.write_only:
self._analyze_field(field_name, field, model, field_info, analysis)
return analysis
def _analyze_field(self, field_name, field, model, field_info, analysis):
"""Analyze a single field for potential N+1 issues."""
source = getattr(field, 'source', field_name)
source_parts = source.split('.')
base_field_name = source_parts[0]
# Check if it's a ManyRelatedField (many=True on RelatedField)
# This handles custom fields like PrimaryKeyRelatedField(many=True)
if isinstance(field, ManyRelatedField):
# ManyToMany or reverse relationship - use prefetch_related
if base_field_name not in analysis['prefetch_related']:
analysis['prefetch_related'].append(base_field_name)
return # Early return since ManyRelatedField is handled
# Check if it's a related field
if isinstance(field, RelatedField):
# Check if it's in the model's relationships
if base_field_name in field_info.relations:
relation_info = field_info.relations[base_field_name]
if not relation_info.to_many:
# ForeignKey or OneToOneField - use select_related
if base_field_name not in analysis['select_related']:
analysis['select_related'].append(base_field_name)
# Check for nested relationships
if len(source_parts) > 1 and relation_info.related_model:
self._analyze_nested_relationship(
relation_info.related_model, source_parts[1:], analysis
)
else:
# ManyToMany or reverse relationship - use prefetch_related
if base_field_name not in analysis['prefetch_related']:
analysis['prefetch_related'].append(base_field_name)
else:
# Field not in relations, but might be a custom field that maps to a model field
# Check if the field name matches a model ManyToMany field
try:
model_field = model._meta.get_field(base_field_name)
if isinstance(model_field, ManyToManyField):
if base_field_name not in analysis['prefetch_related']:
analysis['prefetch_related'].append(base_field_name)
except (models.FieldDoesNotExist, AttributeError):
# Field doesn't exist on model, might be a property or method
pass
# Check if it's a nested serializer
elif isinstance(field, serializers.Serializer):
analysis['nested_serializers'].append(field_name)
# First, ensure the base relationship is optimized
if base_field_name in field_info.relations:
relation_info = field_info.relations[base_field_name]
if not relation_info.to_many:
# ForeignKey or OneToOneField - use select_related
if base_field_name not in analysis['select_related']:
analysis['select_related'].append(base_field_name)
else:
# ManyToMany or reverse relationship - use prefetch_related
if base_field_name not in analysis['prefetch_related']:
analysis['prefetch_related'].append(base_field_name)
# Check if the nested serializer has a model for deeper analysis
try:
if hasattr(field, 'Meta') and hasattr(field.Meta, 'model'):
nested_model = field.Meta.model
# Recursively analyze nested serializer
nested_analyzer = QueryAnalyzer(field.__class__)
nested_analysis = nested_analyzer.analyze()
# Merge nested analysis
if source_parts:
base_field = source_parts[0]
# Add nested select_related/prefetch_related
for nested_field in nested_analysis.get('select_related', []):
full_path = f"{base_field}__{nested_field}"
if full_path not in analysis['select_related']:
analysis['select_related'].append(full_path)
for nested_field in nested_analysis.get('prefetch_related', []):
full_path = f"{base_field}__{nested_field}"
if full_path not in analysis['prefetch_related']:
analysis['prefetch_related'].append(full_path)
except Exception:
# If nested serializer analysis fails, we've already handled the base relationship above
pass
def _analyze_nested_relationship(self, related_model, path_parts, analysis):
"""Analyze nested relationships (e.g., 'author__profile')."""
if not path_parts:
return
try:
field = related_model._meta.get_field(path_parts[0])
if isinstance(field, (ForeignKey, OneToOneField)):
full_path = '__'.join(path_parts)
if full_path not in analysis['select_related']:
analysis['select_related'].append(full_path)
elif isinstance(field, ManyToManyField):
full_path = '__'.join(path_parts)
if full_path not in analysis['prefetch_related']:
analysis['prefetch_related'].append(full_path)
except models.FieldDoesNotExist:
pass
def detect_n_plus_one(serializer_class, queryset):
"""
Detect potential N+1 query issues for a serializer and queryset.
Args:
serializer_class: The serializer class to analyze
queryset: The queryset that will be serialized
Returns:
List of warning messages about potential N+1 queries
"""
warnings_list = []
if not hasattr(queryset, 'query'):
# Not a queryset, can't analyze
return warnings_list
analyzer = QueryAnalyzer(serializer_class)
analysis = analyzer.analyze()
# Check if queryset has optimizations
query = queryset.query
select_related = getattr(query, 'select_related', {})
prefetch_related = getattr(query, 'prefetch_related_lookups', set())
# Check for missing select_related
for field in analysis.get('select_related', []):
# If select_related is True, all fields are selected
if select_related is True:
continue
elif isinstance(select_related, dict):
if field not in select_related and not any(
field.startswith(sel) for sel in select_related.keys()
):
warnings_list.append(
f"Potential N+1 query detected: Consider using "
f"select_related('{field}') for field '{field}'"
)
else:
# No select_related, add warning
warnings_list.append(
f"Potential N+1 query detected: Consider using "
f"select_related('{field}') for field '{field}'"
)
# Check for missing prefetch_related
for field in analysis.get('prefetch_related', []):
if field not in prefetch_related and not any(
field.startswith(pref) for pref in prefetch_related
):
warnings_list.append(
f"Potential N+1 query detected: Consider using "
f"prefetch_related('{field}') for field '{field}'"
)
return warnings_list

View File

@ -128,6 +128,10 @@ DEFAULTS = {
'retrieve': 'read',
'destroy': 'delete'
},
# Query Optimization
'ENABLE_QUERY_OPTIMIZATION': True,
'WARN_ON_N_PLUS_ONE': None, # None means auto-detect based on DEBUG
}