diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 1d9efaa43..9934a0780 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -636,6 +636,8 @@ class Test2555Regression: nested = NestedSerializer() serializer = ParentSerializer(data={}, context={'foo': 'bar'}) + serializer.context.pop('request', None) + serializer.fields['nested'].context.pop('request', None) assert serializer.context == {'foo': 'bar'} assert serializer.fields['nested'].context == {'foo': 'bar'} diff --git a/tests/test_views.py b/tests/test_views.py index 2648c9fb3..e0a8defde 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1,9 +1,12 @@ import copy +import sys +import pytest from django.test import TestCase -from rest_framework import status +from rest_framework import serializers, status from rest_framework.decorators import api_view +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import APISettings, api_settings from rest_framework.test import APIRequestFactory @@ -22,6 +25,16 @@ class BasicView(APIView): return Response({'method': 'POST', 'data': request.data}) +class ViewWithSerializer(APIView): + serializer_class = serializers.Serializer + + def get(self, request, *args, **kwargs): + serializer = self.serializer_class() + response = Response() + response.serializer = serializer + return response + + @api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': @@ -136,3 +149,20 @@ class TestCustomSettings(TestCase): response = self.view(request) assert response.status_code == 400 assert response.data == {'error': 'SyntaxError'} + + +class TestRequestContextVar(TestCase): + def setUp(self): + self.view = ViewWithSerializer.as_view() + + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="contextvars require Python 3.7 or higher", + ) + def test_request_context_var(self): + request = factory.get('/', content_type='application/json') + response = self.view(request) + assert Request.get_current() is not None + assert Request.get_current().method == request.method + assert Request.get_current().path == request.path + assert response.serializer.context['request'] is Request.get_current()