django-rest-framework/rest_framework/optimization/middleware.py
malikabdullahnazar 3ff4f68883 Add query optimization module and settings
- Add rest_framework/optimization module with query analyzer, optimizer, mixins, and middleware
- Add ENABLE_QUERY_OPTIMIZATION and WARN_ON_N_PLUS_ONE settings
- Add comprehensive test suite in tests/test_optimization.py

This feature provides automatic query optimization to prevent N+1 query
problems by analyzing serializer fields and applying select_related()
and prefetch_related() optimizations automatically.
2025-11-25 22:29:15 +05:00

119 lines
4.2 KiB
Python

"""
Middleware for detecting N+1 queries in development mode.
This middleware can be added to Django's MIDDLEWARE setting to automatically
detect and warn about N+1 query problems during development.
"""
import warnings
from django.conf import settings
from django.db import connection
from django.utils.deprecation import MiddlewareMixin
class QueryOptimizationMiddleware(MiddlewareMixin):
"""
Middleware that detects potential N+1 queries in development mode.
This middleware tracks database queries and warns when patterns that
suggest N+1 queries are detected.
Usage:
Add to MIDDLEWARE in settings.py:
MIDDLEWARE = [
...
'rest_framework.optimization.middleware.QueryOptimizationMiddleware',
]
Settings:
- QUERY_OPTIMIZATION_WARN_THRESHOLD: Number of similar queries to trigger warning (default: 5)
"""
def __init__(self, get_response):
self.get_response = get_response
self.warn_threshold = getattr(
settings,
'QUERY_OPTIMIZATION_WARN_THRESHOLD',
5
)
super().__init__(get_response)
def process_request(self, request):
"""Reset query tracking for each request."""
if settings.DEBUG:
connection.queries_log.clear()
return None
def process_response(self, request, response):
"""Analyze queries and warn about potential N+1 issues."""
if not settings.DEBUG:
return response
try:
# In Django 5.2+, use queries_log, fallback to queries for older versions
if hasattr(connection, 'queries_log'):
queries = connection.queries_log
else:
queries = getattr(connection, 'queries', [])
if len(queries) > self.warn_threshold:
# Analyze queries for patterns
self._analyze_queries(queries, request)
except Exception as e:
# Don't break the request if analysis fails
import traceback
if settings.DEBUG:
# Only log in DEBUG mode to avoid noise
warnings.warn(f"Query optimization middleware error: {e}", UserWarning)
return response
def _analyze_queries(self, queries, request):
"""Analyze queries for N+1 patterns."""
# Group queries by SQL pattern
query_patterns = {}
for query in queries:
# Handle both dict format (old Django) and string format (new Django)
if isinstance(query, dict):
sql = query.get('sql', '')
elif isinstance(query, str):
sql = query
else:
# Django 5.2+ might use a different format
sql = str(query)
if not sql:
continue
# Normalize SQL (remove values, keep structure)
normalized = self._normalize_sql(sql)
if normalized not in query_patterns:
query_patterns[normalized] = []
query_patterns[normalized].append(query)
# Warn about patterns that appear many times (potential N+1)
for pattern, query_list in query_patterns.items():
if len(query_list) >= self.warn_threshold:
# Check if it's a SELECT query (not INSERT/UPDATE/DELETE)
if 'SELECT' in pattern.upper():
warnings.warn(
f"Potential N+1 query detected: {len(query_list)} similar queries "
f"executed for pattern: {pattern[:100]}... "
f"Consider using select_related() or prefetch_related().",
UserWarning
)
def _normalize_sql(self, sql):
"""Normalize SQL by removing values and keeping structure."""
import re
# Remove quoted strings
sql = re.sub(r"'[^']*'", "'?'", sql)
sql = re.sub(r'"[^"]*"', '"?"', sql)
# Remove numbers
sql = re.sub(r'\b\d+\b', '?', sql)
# Normalize whitespace
sql = ' '.join(sql.split())
return sql