From 06d21eb2467e0f7f53f79945faa920314d3ed276 Mon Sep 17 00:00:00 2001 From: dalerzafarovich Date: Tue, 21 Mar 2023 23:35:58 +0500 Subject: [PATCH 1/2] contextvars for request context --- rest_framework/request.py | 11 +++++++++++ rest_framework/serializers.py | 7 +++++++ rest_framework/views.py | 4 +++- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/rest_framework/request.py b/rest_framework/request.py index 93109226d..1b1a0a649 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -8,6 +8,7 @@ The wrapped request then offers a richer API, in particular : - full support of PUT method, including support for file uploads - form overloading of HTTP method, content type and content """ +import contextvars import io import sys from contextlib import contextmanager @@ -149,6 +150,8 @@ class Request: authenticating the request's user. """ + _context_instance = contextvars.ContextVar('request') + def __init__(self, request, parsers=None, authenticators=None, negotiator=None, parser_context=None): assert isinstance(request, HttpRequest), ( @@ -458,3 +461,11 @@ class Request: # Hack to allow our exception handler to force choice of # plaintext or html error responses. self._request.is_ajax = lambda: value + + @classmethod + def get_current(cls, default=None): + return cls._context_instance.get(default) + + @classmethod + def set_current(cls, value): + return cls._context_instance.set(value) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index e27f8a47c..815e97e29 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -29,6 +29,7 @@ from django.utils.translation import gettext_lazy as _ from rest_framework.compat import postgres_fields from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.fields import get_error_detail, set_value +from rest_framework.request import Request from rest_framework.settings import api_settings from rest_framework.utils import html, model_meta, representation from rest_framework.utils.field_mapping import ( @@ -115,6 +116,12 @@ class BaseSerializer(Field): self.partial = kwargs.pop('partial', False) self._context = kwargs.pop('context', {}) kwargs.pop('many', None) + + if isinstance(self._context, dict) and ('request' not in self._context): + request = Request.get_current() + if request: + self._context['request'] = request + super().__init__(**kwargs) def __new__(cls, *args, **kwargs): diff --git a/rest_framework/views.py b/rest_framework/views.py index 4c30029fd..2bc4fe32f 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -489,9 +489,11 @@ class APIView(View): """ self.args = args self.kwargs = kwargs + self.headers = self.default_response_headers # deprecate? + request = self.initialize_request(request, *args, **kwargs) self.request = request - self.headers = self.default_response_headers # deprecate? + Request.set_current(request) try: self.initial(request, *args, **kwargs) From e6106facc90263839116bc74d4ab5608efae00bf Mon Sep 17 00:00:00 2001 From: dalerzafarovich Date: Tue, 21 Mar 2023 23:36:19 +0500 Subject: [PATCH 2/2] tests --- tests/test_serializer.py | 2 ++ tests/test_views.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 1d9efaa43..9934a0780 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -636,6 +636,8 @@ class Test2555Regression: nested = NestedSerializer() serializer = ParentSerializer(data={}, context={'foo': 'bar'}) + serializer.context.pop('request', None) + serializer.fields['nested'].context.pop('request', None) assert serializer.context == {'foo': 'bar'} assert serializer.fields['nested'].context == {'foo': 'bar'} diff --git a/tests/test_views.py b/tests/test_views.py index 2648c9fb3..e0a8defde 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,9 +1,12 @@ import copy +import sys +import pytest from django.test import TestCase -from rest_framework import status +from rest_framework import serializers, status from rest_framework.decorators import api_view +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import APISettings, api_settings from rest_framework.test import APIRequestFactory @@ -22,6 +25,16 @@ class BasicView(APIView): return Response({'method': 'POST', 'data': request.data}) +class ViewWithSerializer(APIView): + serializer_class = serializers.Serializer + + def get(self, request, *args, **kwargs): + serializer = self.serializer_class() + response = Response() + response.serializer = serializer + return response + + @api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': @@ -136,3 +149,20 @@ class TestCustomSettings(TestCase): response = self.view(request) assert response.status_code == 400 assert response.data == {'error': 'SyntaxError'} + + +class TestRequestContextVar(TestCase): + def setUp(self): + self.view = ViewWithSerializer.as_view() + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="contextvars require Python 3.7 or higher", + ) + def test_request_context_var(self): + request = factory.get('/', content_type='application/json') + response = self.view(request) + assert Request.get_current() is not None + assert Request.get_current().method == request.method + assert Request.get_current().path == request.path + assert response.serializer.context['request'] is Request.get_current()