mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-18 04:02:35 +03:00
Merge 6a7df3bb09
into 552a67acde
This commit is contained in:
commit
349b870e2c
|
@ -8,6 +8,7 @@ The wrapped request then offers a richer API, in particular :
|
||||||
- full support of PUT method, including support for file uploads
|
- full support of PUT method, including support for file uploads
|
||||||
- form overloading of HTTP method, content type and content
|
- form overloading of HTTP method, content type and content
|
||||||
"""
|
"""
|
||||||
|
import contextvars
|
||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
@ -149,6 +150,8 @@ class Request:
|
||||||
authenticating the request's user.
|
authenticating the request's user.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_context_instance = contextvars.ContextVar('request')
|
||||||
|
|
||||||
def __init__(self, request, parsers=None, authenticators=None,
|
def __init__(self, request, parsers=None, authenticators=None,
|
||||||
negotiator=None, parser_context=None):
|
negotiator=None, parser_context=None):
|
||||||
assert isinstance(request, HttpRequest), (
|
assert isinstance(request, HttpRequest), (
|
||||||
|
@ -458,3 +461,11 @@ class Request:
|
||||||
# Hack to allow our exception handler to force choice of
|
# Hack to allow our exception handler to force choice of
|
||||||
# plaintext or html error responses.
|
# plaintext or html error responses.
|
||||||
self._request.is_ajax = lambda: value
|
self._request.is_ajax = lambda: value
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_current(cls, default=None):
|
||||||
|
return cls._context_instance.get(default)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def set_current(cls, value):
|
||||||
|
return cls._context_instance.set(value)
|
||||||
|
|
|
@ -28,7 +28,8 @@ from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
from rest_framework.compat import postgres_fields
|
from rest_framework.compat import postgres_fields
|
||||||
from rest_framework.exceptions import ErrorDetail, ValidationError
|
from rest_framework.exceptions import ErrorDetail, ValidationError
|
||||||
from rest_framework.fields import get_error_detail
|
from rest_framework.fields import get_error_detail, set_value
|
||||||
|
from rest_framework.request import Request
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
from rest_framework.utils import html, model_meta, representation
|
from rest_framework.utils import html, model_meta, representation
|
||||||
from rest_framework.utils.field_mapping import (
|
from rest_framework.utils.field_mapping import (
|
||||||
|
@ -115,6 +116,12 @@ class BaseSerializer(Field):
|
||||||
self.partial = kwargs.pop('partial', False)
|
self.partial = kwargs.pop('partial', False)
|
||||||
self._context = kwargs.pop('context', {})
|
self._context = kwargs.pop('context', {})
|
||||||
kwargs.pop('many', None)
|
kwargs.pop('many', None)
|
||||||
|
|
||||||
|
if isinstance(self._context, dict) and ('request' not in self._context):
|
||||||
|
request = Request.get_current()
|
||||||
|
if request:
|
||||||
|
self._context['request'] = request
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
|
|
|
@ -489,9 +489,11 @@ class APIView(View):
|
||||||
"""
|
"""
|
||||||
self.args = args
|
self.args = args
|
||||||
self.kwargs = kwargs
|
self.kwargs = kwargs
|
||||||
|
self.headers = self.default_response_headers # deprecate?
|
||||||
|
|
||||||
request = self.initialize_request(request, *args, **kwargs)
|
request = self.initialize_request(request, *args, **kwargs)
|
||||||
self.request = request
|
self.request = request
|
||||||
self.headers = self.default_response_headers # deprecate?
|
Request.set_current(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.initial(request, *args, **kwargs)
|
self.initial(request, *args, **kwargs)
|
||||||
|
|
|
@ -637,6 +637,8 @@ class Test2555Regression:
|
||||||
nested = NestedSerializer()
|
nested = NestedSerializer()
|
||||||
|
|
||||||
serializer = ParentSerializer(data={}, context={'foo': 'bar'})
|
serializer = ParentSerializer(data={}, context={'foo': 'bar'})
|
||||||
|
serializer.context.pop('request', None)
|
||||||
|
serializer.fields['nested'].context.pop('request', None)
|
||||||
assert serializer.context == {'foo': 'bar'}
|
assert serializer.context == {'foo': 'bar'}
|
||||||
assert serializer.fields['nested'].context == {'foo': 'bar'}
|
assert serializer.fields['nested'].context == {'foo': 'bar'}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
import copy
|
import copy
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from rest_framework import status
|
from rest_framework import serializers, status
|
||||||
from rest_framework.decorators import api_view
|
from rest_framework.decorators import api_view
|
||||||
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.settings import APISettings, api_settings
|
from rest_framework.settings import APISettings, api_settings
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
|
@ -22,6 +25,16 @@ class BasicView(APIView):
|
||||||
return Response({'method': 'POST', 'data': request.data})
|
return Response({'method': 'POST', 'data': request.data})
|
||||||
|
|
||||||
|
|
||||||
|
class ViewWithSerializer(APIView):
|
||||||
|
serializer_class = serializers.Serializer
|
||||||
|
|
||||||
|
def get(self, request, *args, **kwargs):
|
||||||
|
serializer = self.serializer_class()
|
||||||
|
response = Response()
|
||||||
|
response.serializer = serializer
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@api_view(['GET', 'POST', 'PUT', 'PATCH'])
|
@api_view(['GET', 'POST', 'PUT', 'PATCH'])
|
||||||
def basic_view(request):
|
def basic_view(request):
|
||||||
if request.method == 'GET':
|
if request.method == 'GET':
|
||||||
|
@ -136,3 +149,20 @@ class TestCustomSettings(TestCase):
|
||||||
response = self.view(request)
|
response = self.view(request)
|
||||||
assert response.status_code == 400
|
assert response.status_code == 400
|
||||||
assert response.data == {'error': 'SyntaxError'}
|
assert response.data == {'error': 'SyntaxError'}
|
||||||
|
|
||||||
|
|
||||||
|
class TestRequestContextVar(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.view = ViewWithSerializer.as_view()
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
sys.version_info < (3, 7),
|
||||||
|
reason="contextvars require Python 3.7 or higher",
|
||||||
|
)
|
||||||
|
def test_request_context_var(self):
|
||||||
|
request = factory.get('/', content_type='application/json')
|
||||||
|
response = self.view(request)
|
||||||
|
assert Request.get_current() is not None
|
||||||
|
assert Request.get_current().method == request.method
|
||||||
|
assert Request.get_current().path == request.path
|
||||||
|
assert response.serializer.context['request'] is Request.get_current()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user