From 700cbf19663099dad883c8fb602414556f37d186 Mon Sep 17 00:00:00 2001 From: Oguntunde Caleb Date: Fri, 12 May 2023 00:51:53 +0100 Subject: [PATCH] Add middleware classes support that run after drf mutate request --- .gitignore | 1 + docs/api-guide/settings.md | 7 ++ rest_framework/settings.py | 201 ++++++++++++++++++------------------- rest_framework/views.py | 122 +++++++++++++--------- 4 files changed, 178 insertions(+), 153 deletions(-) diff --git a/.gitignore b/.gitignore index 641714d16..67ab45fce 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ /env/ MANIFEST coverage.* +venv !.github !.gitignore diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index d42000260..4097c4171 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -41,6 +41,13 @@ The `api_settings` object will check for any user-defined settings, and otherwis *The following settings control the basic API policies, and are applied to every `APIView` class-based view, or `@api_view` function based view.* +#### MIDDLEWARE_CLASSES + +A list or tuple of middleware classes, that is run prior to calling the method handler. + +Default: `[]` + + #### DEFAULT_RENDERER_CLASSES A list or tuple of renderer classes, that determines the default set of renderers that may be used when returning a `Response` object. diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 96b664574..5f0406ba9 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -19,6 +19,7 @@ REST framework settings, checking for user settings first, then falling back to the defaults. """ from django.conf import settings + # Import from `django.core.signals` instead of the official location # `django.test.signals` to avoid importing the test module unnecessarily. from django.core.signals import setting_changed @@ -27,133 +28,118 @@ from django.utils.module_loading import import_string from rest_framework import ISO_8601 DEFAULTS = { + # custom middleware class to run prior to calling the method handler + "MIDDLEWARE_CLASSES": [], # Base API policies - 'DEFAULT_RENDERER_CLASSES': [ - 'rest_framework.renderers.JSONRenderer', - 'rest_framework.renderers.BrowsableAPIRenderer', + "DEFAULT_RENDERER_CLASSES": [ + "rest_framework.renderers.JSONRenderer", + "rest_framework.renderers.BrowsableAPIRenderer", ], - 'DEFAULT_PARSER_CLASSES': [ - 'rest_framework.parsers.JSONParser', - 'rest_framework.parsers.FormParser', - 'rest_framework.parsers.MultiPartParser' + "DEFAULT_PARSER_CLASSES": [ + "rest_framework.parsers.JSONParser", + "rest_framework.parsers.FormParser", + "rest_framework.parsers.MultiPartParser", ], - 'DEFAULT_AUTHENTICATION_CLASSES': [ - 'rest_framework.authentication.SessionAuthentication', - 'rest_framework.authentication.BasicAuthentication' + "DEFAULT_AUTHENTICATION_CLASSES": [ + "rest_framework.authentication.SessionAuthentication", + "rest_framework.authentication.BasicAuthentication", ], - 'DEFAULT_PERMISSION_CLASSES': [ - 'rest_framework.permissions.AllowAny', + "DEFAULT_PERMISSION_CLASSES": [ + "rest_framework.permissions.AllowAny", ], - 'DEFAULT_THROTTLE_CLASSES': [], - 'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation', - 'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata', - 'DEFAULT_VERSIONING_CLASS': None, - + "DEFAULT_THROTTLE_CLASSES": [], + "DEFAULT_CONTENT_NEGOTIATION_CLASS": "rest_framework.negotiation.DefaultContentNegotiation", + "DEFAULT_METADATA_CLASS": "rest_framework.metadata.SimpleMetadata", + "DEFAULT_VERSIONING_CLASS": None, # Generic view behavior - 'DEFAULT_PAGINATION_CLASS': None, - 'DEFAULT_FILTER_BACKENDS': [], - + "DEFAULT_PAGINATION_CLASS": None, + "DEFAULT_FILTER_BACKENDS": [], # Schema - 'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.openapi.AutoSchema', - + "DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.openapi.AutoSchema", # Throttling - 'DEFAULT_THROTTLE_RATES': { - 'user': None, - 'anon': None, + "DEFAULT_THROTTLE_RATES": { + "user": None, + "anon": None, }, - 'NUM_PROXIES': None, - + "NUM_PROXIES": None, # Pagination - 'PAGE_SIZE': None, - + "PAGE_SIZE": None, # Filtering - 'SEARCH_PARAM': 'search', - 'ORDERING_PARAM': 'ordering', - + "SEARCH_PARAM": "search", + "ORDERING_PARAM": "ordering", # Versioning - 'DEFAULT_VERSION': None, - 'ALLOWED_VERSIONS': None, - 'VERSION_PARAM': 'version', - + "DEFAULT_VERSION": None, + "ALLOWED_VERSIONS": None, + "VERSION_PARAM": "version", # Authentication - 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', - 'UNAUTHENTICATED_TOKEN': None, - + "UNAUTHENTICATED_USER": "django.contrib.auth.models.AnonymousUser", + "UNAUTHENTICATED_TOKEN": None, # View configuration - 'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name', - 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', - + "VIEW_NAME_FUNCTION": "rest_framework.views.get_view_name", + "VIEW_DESCRIPTION_FUNCTION": "rest_framework.views.get_view_description", # Exception handling - 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', - 'NON_FIELD_ERRORS_KEY': 'non_field_errors', - + "EXCEPTION_HANDLER": "rest_framework.views.exception_handler", + "NON_FIELD_ERRORS_KEY": "non_field_errors", # Testing - 'TEST_REQUEST_RENDERER_CLASSES': [ - 'rest_framework.renderers.MultiPartRenderer', - 'rest_framework.renderers.JSONRenderer' + "TEST_REQUEST_RENDERER_CLASSES": [ + "rest_framework.renderers.MultiPartRenderer", + "rest_framework.renderers.JSONRenderer", ], - 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', - + "TEST_REQUEST_DEFAULT_FORMAT": "multipart", # Hyperlink settings - 'URL_FORMAT_OVERRIDE': 'format', - 'FORMAT_SUFFIX_KWARG': 'format', - 'URL_FIELD_NAME': 'url', - + "URL_FORMAT_OVERRIDE": "format", + "FORMAT_SUFFIX_KWARG": "format", + "URL_FIELD_NAME": "url", # Input and output formats - 'DATE_FORMAT': ISO_8601, - 'DATE_INPUT_FORMATS': [ISO_8601], - - 'DATETIME_FORMAT': ISO_8601, - 'DATETIME_INPUT_FORMATS': [ISO_8601], - - 'TIME_FORMAT': ISO_8601, - 'TIME_INPUT_FORMATS': [ISO_8601], - + "DATE_FORMAT": ISO_8601, + "DATE_INPUT_FORMATS": [ISO_8601], + "DATETIME_FORMAT": ISO_8601, + "DATETIME_INPUT_FORMATS": [ISO_8601], + "TIME_FORMAT": ISO_8601, + "TIME_INPUT_FORMATS": [ISO_8601], # Encoding - 'UNICODE_JSON': True, - 'COMPACT_JSON': True, - 'STRICT_JSON': True, - 'COERCE_DECIMAL_TO_STRING': True, - 'UPLOADED_FILES_USE_URL': True, - + "UNICODE_JSON": True, + "COMPACT_JSON": True, + "STRICT_JSON": True, + "COERCE_DECIMAL_TO_STRING": True, + "UPLOADED_FILES_USE_URL": True, # Browseable API - 'HTML_SELECT_CUTOFF': 1000, - 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", - + "HTML_SELECT_CUTOFF": 1000, + "HTML_SELECT_CUTOFF_TEXT": "More than {count} items...", # Schemas - 'SCHEMA_COERCE_PATH_PK': True, - 'SCHEMA_COERCE_METHOD_NAMES': { - 'retrieve': 'read', - 'destroy': 'delete' - }, + "SCHEMA_COERCE_PATH_PK": True, + "SCHEMA_COERCE_METHOD_NAMES": {"retrieve": "read", "destroy": "delete"}, } # List of settings that may be in string import notation. IMPORT_STRINGS = [ - 'DEFAULT_RENDERER_CLASSES', - 'DEFAULT_PARSER_CLASSES', - 'DEFAULT_AUTHENTICATION_CLASSES', - 'DEFAULT_PERMISSION_CLASSES', - 'DEFAULT_THROTTLE_CLASSES', - 'DEFAULT_CONTENT_NEGOTIATION_CLASS', - 'DEFAULT_METADATA_CLASS', - 'DEFAULT_VERSIONING_CLASS', - 'DEFAULT_PAGINATION_CLASS', - 'DEFAULT_FILTER_BACKENDS', - 'DEFAULT_SCHEMA_CLASS', - 'EXCEPTION_HANDLER', - 'TEST_REQUEST_RENDERER_CLASSES', - 'UNAUTHENTICATED_USER', - 'UNAUTHENTICATED_TOKEN', - 'VIEW_NAME_FUNCTION', - 'VIEW_DESCRIPTION_FUNCTION' + "MIDDLEWARE_CLASSES", + "DEFAULT_RENDERER_CLASSES", + "DEFAULT_PARSER_CLASSES", + "DEFAULT_AUTHENTICATION_CLASSES", + "DEFAULT_PERMISSION_CLASSES", + "DEFAULT_THROTTLE_CLASSES", + "DEFAULT_CONTENT_NEGOTIATION_CLASS", + "DEFAULT_METADATA_CLASS", + "DEFAULT_VERSIONING_CLASS", + "DEFAULT_PAGINATION_CLASS", + "DEFAULT_FILTER_BACKENDS", + "DEFAULT_SCHEMA_CLASS", + "EXCEPTION_HANDLER", + "TEST_REQUEST_RENDERER_CLASSES", + "UNAUTHENTICATED_USER", + "UNAUTHENTICATED_TOKEN", + "VIEW_NAME_FUNCTION", + "VIEW_DESCRIPTION_FUNCTION", ] # List of settings that have been removed REMOVED_SETTINGS = [ - 'PAGINATE_BY', 'PAGINATE_BY_PARAM', 'MAX_PAGINATE_BY', + "PAGINATE_BY", + "PAGINATE_BY_PARAM", + "MAX_PAGINATE_BY", ] @@ -178,7 +164,12 @@ def import_from_string(val, setting_name): try: return import_string(val) except ImportError as e: - msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e) + msg = "Could not import '%s' for API setting '%s'. %s: %s." % ( + val, + setting_name, + e.__class__.__name__, + e, + ) raise ImportError(msg) @@ -198,6 +189,7 @@ class APISettings: under the REST_FRAMEWORK name. It is not intended to be used by 3rd-party apps, and test helpers like `override_settings` may not work as expected. """ + def __init__(self, user_settings=None, defaults=None, import_strings=None): if user_settings: self._user_settings = self.__check_user_settings(user_settings) @@ -207,8 +199,8 @@ class APISettings: @property def user_settings(self): - if not hasattr(self, '_user_settings'): - self._user_settings = getattr(settings, 'REST_FRAMEWORK', {}) + if not hasattr(self, "_user_settings"): + self._user_settings = getattr(settings, "REST_FRAMEWORK", {}) return self._user_settings def __getattr__(self, attr): @@ -235,23 +227,26 @@ class APISettings: SETTINGS_DOC = "https://www.django-rest-framework.org/api-guide/settings/" for setting in REMOVED_SETTINGS: if setting in user_settings: - raise RuntimeError("The '%s' setting has been removed. Please refer to '%s' for available settings." % (setting, SETTINGS_DOC)) + raise RuntimeError( + "The '%s' setting has been removed. Please refer to '%s' for available settings." + % (setting, SETTINGS_DOC) + ) return user_settings def reload(self): for attr in self._cached_attrs: delattr(self, attr) self._cached_attrs.clear() - if hasattr(self, '_user_settings'): - delattr(self, '_user_settings') + if hasattr(self, "_user_settings"): + delattr(self, "_user_settings") api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS) def reload_api_settings(*args, **kwargs): - setting = kwargs['setting'] - if setting == 'REST_FRAMEWORK': + setting = kwargs["setting"] + if setting == "REST_FRAMEWORK": api_settings.reload() diff --git a/rest_framework/views.py b/rest_framework/views.py index 4c30029fd..ecf43e62b 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -27,19 +27,19 @@ def get_view_name(view): This function is the default for the `VIEW_NAME_FUNCTION` setting. """ # Name may be set by some Views, such as a ViewSet. - name = getattr(view, 'name', None) + name = getattr(view, "name", None) if name is not None: return name name = view.__class__.__name__ - name = formatting.remove_trailing_string(name, 'View') - name = formatting.remove_trailing_string(name, 'ViewSet') + name = formatting.remove_trailing_string(name, "View") + name = formatting.remove_trailing_string(name, "ViewSet") name = formatting.camelcase_to_spaces(name) # Suffix may be set by some Views, such as a ViewSet. - suffix = getattr(view, 'suffix', None) + suffix = getattr(view, "suffix", None) if suffix: - name += ' ' + suffix + name += " " + suffix return name @@ -52,9 +52,9 @@ def get_view_description(view, html=False): This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting. """ # Description may be set by some Views, such as a ViewSet. - description = getattr(view, 'description', None) + description = getattr(view, "description", None) if description is None: - description = view.__class__.__doc__ or '' + description = view.__class__.__doc__ or "" description = formatting.dedent(smart_str(description)) if html: @@ -64,7 +64,7 @@ def get_view_description(view, html=False): def set_rollback(): for db in connections.all(): - if db.settings_dict['ATOMIC_REQUESTS'] and db.in_atomic_block: + if db.settings_dict["ATOMIC_REQUESTS"] and db.in_atomic_block: db.set_rollback(True) @@ -85,15 +85,15 @@ def exception_handler(exc, context): if isinstance(exc, exceptions.APIException): headers = {} - if getattr(exc, 'auth_header', None): - headers['WWW-Authenticate'] = exc.auth_header - if getattr(exc, 'wait', None): - headers['Retry-After'] = '%d' % exc.wait + if getattr(exc, "auth_header", None): + headers["WWW-Authenticate"] = exc.auth_header + if getattr(exc, "wait", None): + headers["Retry-After"] = "%d" % exc.wait if isinstance(exc.detail, (list, dict)): data = exc.detail else: - data = {'detail': exc.detail} + data = {"detail": exc.detail} set_rollback() return Response(data, status=exc.status_code, headers=headers) @@ -102,8 +102,8 @@ def exception_handler(exc, context): class APIView(View): - # The following policies may be set at either globally, or per-view. + middleware_classes = api_settings.MIDDLEWARE_CLASSES renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES parser_classes = api_settings.DEFAULT_PARSER_CLASSES authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES @@ -126,13 +126,15 @@ class APIView(View): This allows us to discover information about the view when we do URL reverse lookups. Used for breadcrumb generation. """ - if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet): + if isinstance(getattr(cls, "queryset", None), models.query.QuerySet): + def force_evaluation(): raise RuntimeError( - 'Do not evaluate the `.queryset` attribute directly, ' - 'as the result will be cached and reused between requests. ' - 'Use `.all()` or call `.get_queryset()` instead.' + "Do not evaluate the `.queryset` attribute directly, " + "as the result will be cached and reused between requests. " + "Use `.all()` or call `.get_queryset()` instead." ) + cls.queryset._fetch_all = force_evaluation view = super().as_view(**initkwargs) @@ -153,10 +155,10 @@ class APIView(View): @property def default_response_headers(self): headers = { - 'Allow': ', '.join(self.allowed_methods), + "Allow": ", ".join(self.allowed_methods), } if len(self.renderer_classes) > 1: - headers['Vary'] = 'Accept' + headers["Vary"] = "Accept" return headers def http_method_not_allowed(self, request, *args, **kwargs): @@ -197,9 +199,9 @@ class APIView(View): # Note: Additionally `request` and `encoding` will also be added # to the context by the Request object. return { - 'view': self, - 'args': getattr(self, 'args', ()), - 'kwargs': getattr(self, 'kwargs', {}) + "view": self, + "args": getattr(self, "args", ()), + "kwargs": getattr(self, "kwargs", {}), } def get_renderer_context(self): @@ -210,10 +212,10 @@ class APIView(View): # Note: Additionally 'response' will also be added to the context, # by the Response object. return { - 'view': self, - 'args': getattr(self, 'args', ()), - 'kwargs': getattr(self, 'kwargs', {}), - 'request': getattr(self, 'request', None) + "view": self, + "args": getattr(self, "args", ()), + "kwargs": getattr(self, "kwargs", {}), + "request": getattr(self, "request", None), } def get_exception_handler_context(self): @@ -222,10 +224,10 @@ class APIView(View): as the `context` argument. """ return { - 'view': self, - 'args': getattr(self, 'args', ()), - 'kwargs': getattr(self, 'kwargs', {}), - 'request': getattr(self, 'request', None) + "view": self, + "args": getattr(self, "args", ()), + "kwargs": getattr(self, "kwargs", {}), + "request": getattr(self, "request", None), } def get_view_name(self): @@ -287,7 +289,7 @@ class APIView(View): """ Instantiate and return the content negotiation class to use. """ - if not getattr(self, '_negotiator', None): + if not getattr(self, "_negotiator", None): self._negotiator = self.content_negotiation_class() return self._negotiator @@ -332,8 +334,8 @@ class APIView(View): if not permission.has_permission(request, self): self.permission_denied( request, - message=getattr(permission, 'message', None), - code=getattr(permission, 'code', None) + message=getattr(permission, "message", None), + code=getattr(permission, "code", None), ) def check_object_permissions(self, request, obj): @@ -345,8 +347,8 @@ class APIView(View): if not permission.has_object_permission(request, self, obj): self.permission_denied( request, - message=getattr(permission, 'message', None), - code=getattr(permission, 'code', None) + message=getattr(permission, "message", None), + code=getattr(permission, "code", None), ) def check_throttles(self, request): @@ -363,8 +365,7 @@ class APIView(View): # Filter out `None` values which may happen in case of config / rate # changes, see #1438 durations = [ - duration for duration in throttle_durations - if duration is not None + duration for duration in throttle_durations if duration is not None ] duration = max(durations, default=None) @@ -382,6 +383,23 @@ class APIView(View): # Dispatch methods + def get_middleware_classes(self): + """ + get list of middleware class instance + """ + + return [middleware() for middleware in self.middleware_classes] + + def initialize_middleware_classes(self, request): + """ + Run custom middleware classes before prior to calling the method handler + """ + + for middleware in self.get_middleware_classes(): + middleware(request) + + return request # Return mutated request + def initialize_request(self, request, *args, **kwargs): """ Returns the initial request object. @@ -393,7 +411,7 @@ class APIView(View): parsers=self.get_parsers(), authenticators=self.get_authenticators(), negotiator=self.get_content_negotiator(), - parser_context=parser_context + parser_context=parser_context, ) def initial(self, request, *args, **kwargs): @@ -415,19 +433,21 @@ class APIView(View): self.check_permissions(request) self.check_throttles(request) + # authentication and other task ran before final mutation + self.initialize_middleware_classes(request) + def finalize_response(self, request, response, *args, **kwargs): """ Returns the final response object. """ # Make the error obvious if a proper response is not returned assert isinstance(response, HttpResponseBase), ( - 'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` ' - 'to be returned from the view, but received a `%s`' - % type(response) + "Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` " + "to be returned from the view, but received a `%s`" % type(response) ) if isinstance(response, Response): - if not getattr(request, 'accepted_renderer', None): + if not getattr(request, "accepted_renderer", None): neg = self.perform_content_negotiation(request, force=True) request.accepted_renderer, request.accepted_media_type = neg @@ -436,7 +456,7 @@ class APIView(View): response.renderer_context = self.get_renderer_context() # Add new vary headers to the response instead of overwriting. - vary_headers = self.headers.pop('Vary', None) + vary_headers = self.headers.pop("Vary", None) if vary_headers is not None: patch_vary_headers(response, cc_delim_re.split(vary_headers)) @@ -450,8 +470,9 @@ class APIView(View): Handle any exception that occurs, by returning an appropriate response, or re-raising the error. """ - if isinstance(exc, (exceptions.NotAuthenticated, - exceptions.AuthenticationFailed)): + if isinstance( + exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed) + ): # WWW-Authenticate header for 401 responses, else coerce to 403 auth_header = self.get_authenticate_header(self.request) @@ -474,8 +495,8 @@ class APIView(View): def raise_uncaught_exception(self, exc): if settings.DEBUG: request = self.request - renderer_format = getattr(request.accepted_renderer, 'format') - use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin') + renderer_format = getattr(request.accepted_renderer, "format") + use_plaintext_traceback = renderer_format not in ("html", "api", "admin") request.force_plaintext_errors(use_plaintext_traceback) raise exc @@ -498,8 +519,9 @@ class APIView(View): # Get the appropriate handler method if request.method.lower() in self.http_method_names: - handler = getattr(self, request.method.lower(), - self.http_method_not_allowed) + handler = getattr( + self, request.method.lower(), self.http_method_not_allowed + ) else: handler = self.http_method_not_allowed