diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 080816094..5924d4651 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -41,8 +41,8 @@ except ImportError: uritemplate = None -# async_to_sync is required for async View support -if django.VERSION >= (3, 1): +# async_to_sync is required for async view support +if django.VERSION >= (4, 1): from asgiref.sync import async_to_sync else: async_to_sync = None diff --git a/rest_framework/views.py b/rest_framework/views.py index 96591d591..f30da98d3 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -14,7 +14,6 @@ from django.views.decorators.csrf import csrf_exempt from django.views.generic import View from rest_framework import exceptions, status -from rest_framework.compat import async_to_sync from rest_framework.request import Request from rest_framework.response import Response from rest_framework.schemas import DefaultSchema @@ -506,12 +505,7 @@ class APIView(View): else: handler = self.http_method_not_allowed - if asyncio.iscoroutinefunction(handler): - if not async_to_sync: - raise Exception('Async API views are supported only for django>=3.1.') - response = async_to_sync(handler)(request, *args, **kwargs) - else: - response = handler(request, *args, **kwargs) + response = handler(request, *args, **kwargs) except Exception as exc: response = self.handle_exception(exc) @@ -527,3 +521,37 @@ class APIView(View): return self.http_method_not_allowed(request, *args, **kwargs) data = self.metadata_class().determine_metadata(request, self) return Response(data, status=status.HTTP_200_OK) + + +class AsyncAPIView(APIView): + async def dispatch(self, request, *args, **kwargs): + """ + `.dispatch()` is pretty much the same as Django's regular dispatch, + but with extra hooks for startup, finalize, and exception handling. + """ + self.args = args + self.kwargs = kwargs + request = self.initialize_request(request, *args, **kwargs) + self.request = request + self.headers = self.default_response_headers # deprecate? + + try: + self.initial(request, *args, **kwargs) + + # Get the appropriate handler method + if request.method.lower() in self.http_method_names: + handler = getattr(self, request.method.lower(), + self.http_method_not_allowed) + else: + handler = self.http_method_not_allowed + + if asyncio.iscoroutinefunction(handler): + response = await handler(request, *args, **kwargs) + else: + raise Exception('Async methods should be used on an async view.') + + except Exception as exc: + response = self.handle_exception(exc) + + self.response = self.finalize_response(request, response, *args, **kwargs) + return self.response diff --git a/tests/test_views.py b/tests/test_views.py index 61d073270..64fceb0d6 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -5,11 +5,12 @@ import pytest from django.test import TestCase from rest_framework import status +from rest_framework.compat import async_to_sync from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import APISettings, api_settings from rest_framework.test import APIRequestFactory -from rest_framework.views import APIView +from rest_framework.views import APIView, AsyncAPIView factory = APIRequestFactory() @@ -24,7 +25,7 @@ class BasicView(APIView): return Response({'method': 'POST', 'data': request.data}) -class BasicAsyncView(APIView): +class BasicAsyncView(AsyncAPIView): async def get(self, request, *args, **kwargs): return Response({'method': 'GET'}) @@ -139,8 +140,8 @@ class FunctionBasedViewIntegrationTests(TestCase): @pytest.mark.skipif( - django.VERSION < (3, 1), - reason="Async view support requires Django 3.1 or higher", + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", ) class ClassBasedAsyncViewIntegrationTests(TestCase): def setUp(self): @@ -148,13 +149,13 @@ class ClassBasedAsyncViewIntegrationTests(TestCase): def test_get_succeeds(self): request = factory.get('/') - response = self.view(request) + response = async_to_sync(self.view)(request) assert response.status_code == status.HTTP_200_OK assert response.data == {'method': 'GET'} def test_post_succeeds(self): request = factory.post('/', {'test': 'foo'}) - response = self.view(request) + response = async_to_sync(self.view)(request) expected = { 'method': 'POST', 'data': {'test': ['foo']} @@ -164,7 +165,7 @@ class ClassBasedAsyncViewIntegrationTests(TestCase): def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') - response = self.view(request) + response = async_to_sync(self.view)(request) expected = { 'detail': JSON_ERROR }