From 3ff4f68883a7e07398437e2c20d2a4ca4239cb17 Mon Sep 17 00:00:00 2001 From: malikabdullahnazar Date: Tue, 25 Nov 2025 22:29:15 +0500 Subject: [PATCH] Add query optimization module and settings - Add rest_framework/optimization module with query analyzer, optimizer, mixins, and middleware - Add ENABLE_QUERY_OPTIMIZATION and WARN_ON_N_PLUS_ONE settings - Add comprehensive test suite in tests/test_optimization.py This feature provides automatic query optimization to prevent N+1 query problems by analyzing serializer fields and applying select_related() and prefetch_related() optimizations automatically. --- rest_framework/optimization/__init__.py | 27 ++ rest_framework/optimization/middleware.py | 118 +++++++++ rest_framework/optimization/mixins.py | 119 +++++++++ rest_framework/optimization/optimizer.py | 163 ++++++++++++ rest_framework/optimization/query_analyzer.py | 248 ++++++++++++++++++ rest_framework/settings.py | 4 + 6 files changed, 679 insertions(+) create mode 100644 rest_framework/optimization/__init__.py create mode 100644 rest_framework/optimization/middleware.py create mode 100644 rest_framework/optimization/mixins.py create mode 100644 rest_framework/optimization/optimizer.py create mode 100644 rest_framework/optimization/query_analyzer.py diff --git a/rest_framework/optimization/__init__.py b/rest_framework/optimization/__init__.py new file mode 100644 index 000000000..7b4927067 --- /dev/null +++ b/rest_framework/optimization/__init__.py @@ -0,0 +1,27 @@ +""" +Query optimization utilities for Django REST Framework. + +This module provides tools to automatically detect and prevent N+1 query problems +in DRF serializers by analyzing serializer fields and optimizing querysets. +""" + +from rest_framework.optimization.mixins import OptimizedQuerySetMixin +from rest_framework.optimization.optimizer import ( + optimize_queryset, + analyze_serializer_fields, + get_optimization_suggestions, +) +from rest_framework.optimization.query_analyzer import ( + QueryAnalyzer, + detect_n_plus_one, +) + +__all__ = [ + 'OptimizedQuerySetMixin', + 'optimize_queryset', + 'analyze_serializer_fields', + 'get_optimization_suggestions', + 'QueryAnalyzer', + 'detect_n_plus_one', +] + diff --git a/rest_framework/optimization/middleware.py b/rest_framework/optimization/middleware.py new file mode 100644 index 000000000..13a620a6f --- /dev/null +++ b/rest_framework/optimization/middleware.py @@ -0,0 +1,118 @@ +""" +Middleware for detecting N+1 queries in development mode. + +This middleware can be added to Django's MIDDLEWARE setting to automatically +detect and warn about N+1 query problems during development. +""" + +import warnings +from django.conf import settings +from django.db import connection +from django.utils.deprecation import MiddlewareMixin + + +class QueryOptimizationMiddleware(MiddlewareMixin): + """ + Middleware that detects potential N+1 queries in development mode. + + This middleware tracks database queries and warns when patterns that + suggest N+1 queries are detected. + + Usage: + Add to MIDDLEWARE in settings.py: + + MIDDLEWARE = [ + ... + 'rest_framework.optimization.middleware.QueryOptimizationMiddleware', + ] + + Settings: + - QUERY_OPTIMIZATION_WARN_THRESHOLD: Number of similar queries to trigger warning (default: 5) + """ + + def __init__(self, get_response): + self.get_response = get_response + self.warn_threshold = getattr( + settings, + 'QUERY_OPTIMIZATION_WARN_THRESHOLD', + 5 + ) + super().__init__(get_response) + + def process_request(self, request): + """Reset query tracking for each request.""" + if settings.DEBUG: + connection.queries_log.clear() + return None + + def process_response(self, request, response): + """Analyze queries and warn about potential N+1 issues.""" + if not settings.DEBUG: + return response + + try: + # In Django 5.2+, use queries_log, fallback to queries for older versions + if hasattr(connection, 'queries_log'): + queries = connection.queries_log + else: + queries = getattr(connection, 'queries', []) + + if len(queries) > self.warn_threshold: + # Analyze queries for patterns + self._analyze_queries(queries, request) + except Exception as e: + # Don't break the request if analysis fails + import traceback + if settings.DEBUG: + # Only log in DEBUG mode to avoid noise + warnings.warn(f"Query optimization middleware error: {e}", UserWarning) + + return response + + def _analyze_queries(self, queries, request): + """Analyze queries for N+1 patterns.""" + # Group queries by SQL pattern + query_patterns = {} + for query in queries: + # Handle both dict format (old Django) and string format (new Django) + if isinstance(query, dict): + sql = query.get('sql', '') + elif isinstance(query, str): + sql = query + else: + # Django 5.2+ might use a different format + sql = str(query) + + if not sql: + continue + + # Normalize SQL (remove values, keep structure) + normalized = self._normalize_sql(sql) + if normalized not in query_patterns: + query_patterns[normalized] = [] + query_patterns[normalized].append(query) + + # Warn about patterns that appear many times (potential N+1) + for pattern, query_list in query_patterns.items(): + if len(query_list) >= self.warn_threshold: + # Check if it's a SELECT query (not INSERT/UPDATE/DELETE) + if 'SELECT' in pattern.upper(): + warnings.warn( + f"Potential N+1 query detected: {len(query_list)} similar queries " + f"executed for pattern: {pattern[:100]}... " + f"Consider using select_related() or prefetch_related().", + UserWarning + ) + + def _normalize_sql(self, sql): + """Normalize SQL by removing values and keeping structure.""" + import re + # Remove quoted strings + sql = re.sub(r"'[^']*'", "'?'", sql) + sql = re.sub(r'"[^"]*"', '"?"', sql) + # Remove numbers + sql = re.sub(r'\b\d+\b', '?', sql) + # Normalize whitespace + sql = ' '.join(sql.split()) + return sql + diff --git a/rest_framework/optimization/mixins.py b/rest_framework/optimization/mixins.py new file mode 100644 index 000000000..e3a4aad1e --- /dev/null +++ b/rest_framework/optimization/mixins.py @@ -0,0 +1,119 @@ +""" +Mixins for automatic query optimization in Django REST Framework viewsets. + +This module provides mixins that automatically optimize querysets based on +serializer field analysis. +""" + +import warnings + +from django.db.models import QuerySet + +from rest_framework.settings import api_settings + +from rest_framework.optimization.optimizer import optimize_queryset +from rest_framework.optimization.query_analyzer import detect_n_plus_one + + +class OptimizedQuerySetMixin: + """ + Mixin that automatically optimizes querysets based on serializer analysis. + + This mixin can be added to any GenericAPIView or ViewSet to automatically + apply select_related and prefetch_related optimizations based on the + serializer's field definitions. + + Usage: + class MyViewSet(OptimizedQuerySetMixin, ModelViewSet): + queryset = MyModel.objects.all() + serializer_class = MySerializer + + You can also explicitly specify optimizations: + class MyViewSet(OptimizedQuerySetMixin, ModelViewSet): + queryset = MyModel.objects.all() + serializer_class = MySerializer + select_related_fields = ['author', 'category'] + prefetch_related_fields = ['tags', 'comments'] + + Settings: + - ENABLE_QUERY_OPTIMIZATION: Enable/disable automatic optimization (default: True) + - WARN_ON_N_PLUS_ONE: Show warnings when N+1 queries are detected (default: True in DEBUG) + """ + + # Explicit optimization fields (optional) + select_related_fields = None + prefetch_related_fields = None + + # Control optimization behavior + enable_auto_optimization = True + warn_on_n_plus_one = None + + def get_queryset(self): + """ + Get the queryset with automatic optimizations applied. + + This method extends the base get_queryset() to automatically apply + select_related and prefetch_related based on serializer analysis. + """ + queryset = super().get_queryset() + + # Check if optimization is enabled + enable_optimization = getattr( + api_settings, + 'ENABLE_QUERY_OPTIMIZATION', + self.enable_auto_optimization + ) + + if not enable_optimization: + return queryset + + # Get serializer class + serializer_class = self.get_serializer_class() + if not serializer_class: + return queryset + + # Optimize queryset + try: + queryset = optimize_queryset( + queryset, + serializer_class, + select_related=self.select_related_fields, + prefetch_related=self.prefetch_related_fields, + auto_optimize=self.enable_auto_optimization + ) + except Exception as e: + # If optimization fails, log warning but don't break + if self._should_warn(): + warnings.warn( + f"Query optimization failed: {e}. " + f"Continuing with unoptimized queryset.", + UserWarning + ) + return queryset + + # Check for N+1 queries and warn if enabled + if self._should_warn(): + warnings_list = detect_n_plus_one(serializer_class, queryset) + for warning_msg in warnings_list: + warnings.warn(warning_msg, UserWarning) + + return queryset + + def _should_warn(self): + """Determine if warnings should be shown.""" + if self.warn_on_n_plus_one is not None: + return self.warn_on_n_plus_one + + # Default: warn in DEBUG mode + from django.conf import settings + warn_on_n_plus_one = getattr( + api_settings, + 'WARN_ON_N_PLUS_ONE', + getattr(settings, 'DEBUG', False) + ) + return warn_on_n_plus_one + + +# Backward compatibility alias +QueryOptimizationMixin = OptimizedQuerySetMixin + diff --git a/rest_framework/optimization/optimizer.py b/rest_framework/optimization/optimizer.py new file mode 100644 index 000000000..8ba9a0a49 --- /dev/null +++ b/rest_framework/optimization/optimizer.py @@ -0,0 +1,163 @@ +""" +Query optimizer for automatically optimizing querysets based on serializer fields. + +This module provides utilities to automatically apply select_related and +prefetch_related optimizations to querysets based on serializer field analysis. +""" + +from django.db.models import QuerySet + +from rest_framework import serializers + +from rest_framework.optimization.query_analyzer import QueryAnalyzer + + +def analyze_serializer_fields(serializer_class): + """ + Analyze a serializer class to identify required query optimizations. + + Args: + serializer_class: The serializer class to analyze + + Returns: + Dictionary with 'select_related' and 'prefetch_related' lists + """ + analyzer = QueryAnalyzer(serializer_class) + return analyzer.analyze() + + +def optimize_queryset( + queryset, + serializer_class, + select_related=None, + prefetch_related=None, + auto_optimize=True +): + """ + Optimize a queryset based on serializer analysis and/or explicit parameters. + + Args: + queryset: The queryset to optimize + serializer_class: The serializer class that will be used + select_related: Explicit list of fields for select_related (optional) + prefetch_related: Explicit list of fields for prefetch_related (optional) + auto_optimize: If True, automatically analyze serializer and apply optimizations + + Returns: + Optimized queryset + """ + if not isinstance(queryset, QuerySet): + return queryset + + # Start with the original queryset + optimized = queryset + + # Auto-optimize based on serializer analysis + if auto_optimize: + analysis = analyze_serializer_fields(serializer_class) + + # Merge auto-detected with explicit parameters + if select_related is None: + select_related = analysis.get('select_related', []) + else: + # Merge lists, avoiding duplicates + auto_select = analysis.get('select_related', []) + select_related = list(set(select_related + auto_select)) + + if prefetch_related is None: + prefetch_related = analysis.get('prefetch_related', []) + else: + # Merge lists, avoiding duplicates + auto_prefetch = analysis.get('prefetch_related', []) + prefetch_related = list(set(prefetch_related + auto_prefetch)) + + # Apply select_related + if select_related: + # Check if queryset already has select_related + existing_select = getattr(optimized.query, 'select_related', {}) + + # Handle case where select_related is True (all fields selected) + if existing_select is True: + # All fields already selected, skip + new_select = [] + elif isinstance(existing_select, dict): + # Only add fields that aren't already selected + new_select = [ + field for field in select_related + if field not in existing_select and not any( + field.startswith(sel) for sel in existing_select.keys() + ) + ] + else: + # Empty or unknown format, add all + new_select = select_related + + if new_select: + if len(new_select) == 1: + optimized = optimized.select_related(new_select[0]) + else: + optimized = optimized.select_related(*new_select) + + # Apply prefetch_related + if prefetch_related: + # Check if queryset already has prefetch_related + existing_prefetch = getattr(optimized.query, 'prefetch_related_lookups', set()) + + # Only add fields that aren't already prefetched + new_prefetch = [ + field for field in prefetch_related + if field not in existing_prefetch and not any( + field.startswith(pref) for pref in existing_prefetch + ) + ] + + if new_prefetch: + for field in new_prefetch: + optimized = optimized.prefetch_related(field) + + return optimized + + +def get_optimization_suggestions(serializer_class): + """ + Get optimization suggestions for a serializer class. + + Args: + serializer_class: The serializer class to analyze + + Returns: + Dictionary with optimization suggestions and code examples + """ + analysis = analyze_serializer_fields(serializer_class) + + suggestions = { + 'select_related': analysis.get('select_related', []), + 'prefetch_related': analysis.get('prefetch_related', []), + 'nested_serializers': analysis.get('nested_serializers', []), + 'code_example': None + } + + # Generate code example + if suggestions['select_related'] or suggestions['prefetch_related']: + parts = [] + + if suggestions['select_related']: + if len(suggestions['select_related']) == 1: + parts.append(f".select_related('{suggestions['select_related'][0]}')") + else: + fields_str = "', '".join(suggestions['select_related']) + parts.append(f".select_related('{fields_str}')") + + if suggestions['prefetch_related']: + for field in suggestions['prefetch_related']: + parts.append(f".prefetch_related('{field}')") + + if parts: + suggestions['code_example'] = ( + "def get_queryset(self):\n" + " queryset = super().get_queryset()\n" + f" return queryset{''.join(parts)}" + ) + + return suggestions + diff --git a/rest_framework/optimization/query_analyzer.py b/rest_framework/optimization/query_analyzer.py new file mode 100644 index 000000000..d34d4b310 --- /dev/null +++ b/rest_framework/optimization/query_analyzer.py @@ -0,0 +1,248 @@ +""" +Query analyzer for detecting N+1 query problems in Django REST Framework. + +This module provides utilities to analyze serializer fields and detect potential +N+1 query issues before they occur. +""" + +import warnings + +from django.db import models +from django.db.models import ForeignKey, ManyToManyField, OneToOneField + +from rest_framework import serializers +from rest_framework.relations import RelatedField, ManyRelatedField +from rest_framework.utils import model_meta + + +class QueryAnalyzer: + """ + Analyzes serializer fields to detect potential N+1 query problems. + + This class examines serializer field definitions to identify relationships + that may cause N+1 queries when serializing querysets. + """ + + def __init__(self, serializer_class): + """ + Initialize the analyzer with a serializer class. + + Args: + serializer_class: The serializer class to analyze + """ + self.serializer_class = serializer_class + self._field_analysis = None + + def analyze(self): + """ + Analyze the serializer and return a dictionary with optimization suggestions. + + Returns: + Dictionary containing: + - select_related: List of fields that should use select_related + - prefetch_related: List of fields that should use prefetch_related + - nested_serializers: List of nested serializer fields + """ + if self._field_analysis is None: + self._field_analysis = self._analyze_fields() + return self._field_analysis + + def _analyze_fields(self): + """Analyze serializer fields to identify relationships.""" + analysis = { + 'select_related': [], + 'prefetch_related': [], + 'nested_serializers': [], + 'warnings': [] + } + + if not issubclass(self.serializer_class, serializers.ModelSerializer): + return analysis + + # Get the model from the serializer + model = getattr(self.serializer_class.Meta, 'model', None) + if not model: + return analysis + + # Get field info using DRF's utility + try: + field_info = model_meta.get_field_info(model) + except Exception: + return analysis + + # Analyze declared fields + serializer = self.serializer_class() + fields = serializer.fields + + for field_name, field in fields.items(): + # Analyze fields that are readable (not write_only) + # This includes read_only fields and fields that can be both read and written + if not field.write_only: + self._analyze_field(field_name, field, model, field_info, analysis) + + return analysis + + def _analyze_field(self, field_name, field, model, field_info, analysis): + """Analyze a single field for potential N+1 issues.""" + source = getattr(field, 'source', field_name) + source_parts = source.split('.') + base_field_name = source_parts[0] + + # Check if it's a ManyRelatedField (many=True on RelatedField) + # This handles custom fields like PrimaryKeyRelatedField(many=True) + if isinstance(field, ManyRelatedField): + # ManyToMany or reverse relationship - use prefetch_related + if base_field_name not in analysis['prefetch_related']: + analysis['prefetch_related'].append(base_field_name) + return # Early return since ManyRelatedField is handled + + # Check if it's a related field + if isinstance(field, RelatedField): + # Check if it's in the model's relationships + if base_field_name in field_info.relations: + relation_info = field_info.relations[base_field_name] + + if not relation_info.to_many: + # ForeignKey or OneToOneField - use select_related + if base_field_name not in analysis['select_related']: + analysis['select_related'].append(base_field_name) + + # Check for nested relationships + if len(source_parts) > 1 and relation_info.related_model: + self._analyze_nested_relationship( + relation_info.related_model, source_parts[1:], analysis + ) + else: + # ManyToMany or reverse relationship - use prefetch_related + if base_field_name not in analysis['prefetch_related']: + analysis['prefetch_related'].append(base_field_name) + else: + # Field not in relations, but might be a custom field that maps to a model field + # Check if the field name matches a model ManyToMany field + try: + model_field = model._meta.get_field(base_field_name) + if isinstance(model_field, ManyToManyField): + if base_field_name not in analysis['prefetch_related']: + analysis['prefetch_related'].append(base_field_name) + except (models.FieldDoesNotExist, AttributeError): + # Field doesn't exist on model, might be a property or method + pass + + # Check if it's a nested serializer + elif isinstance(field, serializers.Serializer): + analysis['nested_serializers'].append(field_name) + + # First, ensure the base relationship is optimized + if base_field_name in field_info.relations: + relation_info = field_info.relations[base_field_name] + if not relation_info.to_many: + # ForeignKey or OneToOneField - use select_related + if base_field_name not in analysis['select_related']: + analysis['select_related'].append(base_field_name) + else: + # ManyToMany or reverse relationship - use prefetch_related + if base_field_name not in analysis['prefetch_related']: + analysis['prefetch_related'].append(base_field_name) + + # Check if the nested serializer has a model for deeper analysis + try: + if hasattr(field, 'Meta') and hasattr(field.Meta, 'model'): + nested_model = field.Meta.model + # Recursively analyze nested serializer + nested_analyzer = QueryAnalyzer(field.__class__) + nested_analysis = nested_analyzer.analyze() + + # Merge nested analysis + if source_parts: + base_field = source_parts[0] + # Add nested select_related/prefetch_related + for nested_field in nested_analysis.get('select_related', []): + full_path = f"{base_field}__{nested_field}" + if full_path not in analysis['select_related']: + analysis['select_related'].append(full_path) + + for nested_field in nested_analysis.get('prefetch_related', []): + full_path = f"{base_field}__{nested_field}" + if full_path not in analysis['prefetch_related']: + analysis['prefetch_related'].append(full_path) + except Exception: + # If nested serializer analysis fails, we've already handled the base relationship above + pass + + + def _analyze_nested_relationship(self, related_model, path_parts, analysis): + """Analyze nested relationships (e.g., 'author__profile').""" + if not path_parts: + return + + try: + field = related_model._meta.get_field(path_parts[0]) + if isinstance(field, (ForeignKey, OneToOneField)): + full_path = '__'.join(path_parts) + if full_path not in analysis['select_related']: + analysis['select_related'].append(full_path) + elif isinstance(field, ManyToManyField): + full_path = '__'.join(path_parts) + if full_path not in analysis['prefetch_related']: + analysis['prefetch_related'].append(full_path) + except models.FieldDoesNotExist: + pass + + +def detect_n_plus_one(serializer_class, queryset): + """ + Detect potential N+1 query issues for a serializer and queryset. + + Args: + serializer_class: The serializer class to analyze + queryset: The queryset that will be serialized + + Returns: + List of warning messages about potential N+1 queries + """ + warnings_list = [] + + if not hasattr(queryset, 'query'): + # Not a queryset, can't analyze + return warnings_list + + analyzer = QueryAnalyzer(serializer_class) + analysis = analyzer.analyze() + + # Check if queryset has optimizations + query = queryset.query + select_related = getattr(query, 'select_related', {}) + prefetch_related = getattr(query, 'prefetch_related_lookups', set()) + + # Check for missing select_related + for field in analysis.get('select_related', []): + # If select_related is True, all fields are selected + if select_related is True: + continue + elif isinstance(select_related, dict): + if field not in select_related and not any( + field.startswith(sel) for sel in select_related.keys() + ): + warnings_list.append( + f"Potential N+1 query detected: Consider using " + f"select_related('{field}') for field '{field}'" + ) + else: + # No select_related, add warning + warnings_list.append( + f"Potential N+1 query detected: Consider using " + f"select_related('{field}') for field '{field}'" + ) + + # Check for missing prefetch_related + for field in analysis.get('prefetch_related', []): + if field not in prefetch_related and not any( + field.startswith(pref) for pref in prefetch_related + ): + warnings_list.append( + f"Potential N+1 query detected: Consider using " + f"prefetch_related('{field}') for field '{field}'" + ) + + return warnings_list + diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 50e3ad40e..364bc7db1 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -128,6 +128,10 @@ DEFAULTS = { 'retrieve': 'read', 'destroy': 'delete' }, + + # Query Optimization + 'ENABLE_QUERY_OPTIMIZATION': True, + 'WARN_ON_N_PLUS_ONE': None, # None means auto-detect based on DEBUG }