Update to support async dispatch in AsyncAPIView

This commit is contained in:
enrico 2022-08-25 07:41:38 +08:00
parent 8fa1b7c2b7
commit daac20aa85
3 changed files with 45 additions and 16 deletions

View File

@ -41,8 +41,8 @@ except ImportError:
uritemplate = None uritemplate = None
# async_to_sync is required for async View support # async_to_sync is required for async view support
if django.VERSION >= (3, 1): if django.VERSION >= (4, 1):
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
else: else:
async_to_sync = None async_to_sync = None

View File

@ -14,7 +14,6 @@ from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View from django.views.generic import View
from rest_framework import exceptions, status from rest_framework import exceptions, status
from rest_framework.compat import async_to_sync
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.schemas import DefaultSchema from rest_framework.schemas import DefaultSchema
@ -506,11 +505,6 @@ class APIView(View):
else: else:
handler = self.http_method_not_allowed handler = self.http_method_not_allowed
if asyncio.iscoroutinefunction(handler):
if not async_to_sync:
raise Exception('Async API views are supported only for django>=3.1.')
response = async_to_sync(handler)(request, *args, **kwargs)
else:
response = handler(request, *args, **kwargs) response = handler(request, *args, **kwargs)
except Exception as exc: except Exception as exc:
@ -527,3 +521,37 @@ class APIView(View):
return self.http_method_not_allowed(request, *args, **kwargs) return self.http_method_not_allowed(request, *args, **kwargs)
data = self.metadata_class().determine_metadata(request, self) data = self.metadata_class().determine_metadata(request, self)
return Response(data, status=status.HTTP_200_OK) return Response(data, status=status.HTTP_200_OK)
class AsyncAPIView(APIView):
async def dispatch(self, request, *args, **kwargs):
"""
`.dispatch()` is pretty much the same as Django's regular dispatch,
but with extra hooks for startup, finalize, and exception handling.
"""
self.args = args
self.kwargs = kwargs
request = self.initialize_request(request, *args, **kwargs)
self.request = request
self.headers = self.default_response_headers # deprecate?
try:
self.initial(request, *args, **kwargs)
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
if asyncio.iscoroutinefunction(handler):
response = await handler(request, *args, **kwargs)
else:
raise Exception('Async methods should be used on an async view.')
except Exception as exc:
response = self.handle_exception(exc)
self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response

View File

@ -5,11 +5,12 @@ import pytest
from django.test import TestCase from django.test import TestCase
from rest_framework import status from rest_framework import status
from rest_framework.compat import async_to_sync
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import APISettings, api_settings from rest_framework.settings import APISettings, api_settings
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView, AsyncAPIView
factory = APIRequestFactory() factory = APIRequestFactory()
@ -24,7 +25,7 @@ class BasicView(APIView):
return Response({'method': 'POST', 'data': request.data}) return Response({'method': 'POST', 'data': request.data})
class BasicAsyncView(APIView): class BasicAsyncView(AsyncAPIView):
async def get(self, request, *args, **kwargs): async def get(self, request, *args, **kwargs):
return Response({'method': 'GET'}) return Response({'method': 'GET'})
@ -139,8 +140,8 @@ class FunctionBasedViewIntegrationTests(TestCase):
@pytest.mark.skipif( @pytest.mark.skipif(
django.VERSION < (3, 1), django.VERSION < (4, 1),
reason="Async view support requires Django 3.1 or higher", reason="Async view support requires Django 4.1 or higher",
) )
class ClassBasedAsyncViewIntegrationTests(TestCase): class ClassBasedAsyncViewIntegrationTests(TestCase):
def setUp(self): def setUp(self):
@ -148,13 +149,13 @@ class ClassBasedAsyncViewIntegrationTests(TestCase):
def test_get_succeeds(self): def test_get_succeeds(self):
request = factory.get('/') request = factory.get('/')
response = self.view(request) response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'} assert response.data == {'method': 'GET'}
def test_post_succeeds(self): def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'}) request = factory.post('/', {'test': 'foo'})
response = self.view(request) response = async_to_sync(self.view)(request)
expected = { expected = {
'method': 'POST', 'method': 'POST',
'data': {'test': ['foo']} 'data': {'test': ['foo']}
@ -164,7 +165,7 @@ class ClassBasedAsyncViewIntegrationTests(TestCase):
def test_400_parse_error(self): def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = async_to_sync(self.view)(request)
expected = { expected = {
'detail': JSON_ERROR 'detail': JSON_ERROR
} }