diff --git a/rest_framework/decorators.py b/rest_framework/decorators.py index 3b572c09e..06d42ec09 100644 --- a/rest_framework/decorators.py +++ b/rest_framework/decorators.py @@ -6,6 +6,7 @@ There are also various decorators for setting the API policies on function based views, as well as the `@action` decorator, which is used to annotate methods on viewsets that should be included by routers. """ +import asyncio import types from django.forms.utils import pretty_name @@ -46,11 +47,19 @@ def api_view(http_method_names=None): allowed_methods = set(http_method_names) | {'options'} WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods] - def handler(self, *args, **kwargs): + def sync_handler(self, *args, **kwargs): return func(*args, **kwargs) + async def async_handler(self, *args, **kwargs): + return await func(*args, **kwargs) + + view_is_async = asyncio.iscoroutinefunction(func) + for method in http_method_names: - setattr(WrappedAPIView, method.lower(), handler) + if view_is_async: + setattr(WrappedAPIView, method.lower(), async_handler) + else: + setattr(WrappedAPIView, method.lower(), sync_handler) WrappedAPIView.__name__ = func.__name__ WrappedAPIView.__module__ = func.__module__ diff --git a/rest_framework/views.py b/rest_framework/views.py index 149ca319f..9bb9a61ac 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -484,7 +484,7 @@ class APIView(View): # be overridden. def sync_dispatch(self, request, *args, **kwargs): """ - `.dispatch()` is pretty much the same as Django's regular dispatch, + `.sync_dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. """ self.args = args @@ -513,8 +513,9 @@ class APIView(View): async def async_dispatch(self, request, *args, **kwargs): """ - `.dispatch()` is pretty much the same as Django's regular dispatch, - but with extra hooks for startup, finalize, and exception handling. + `.async_dispatch()` is pretty much the same as Django's regular dispatch, + except for awaiting the handler function and with extra hooks for startup, + finalize, and exception handling. """ self.args = args self.kwargs = kwargs @@ -541,6 +542,10 @@ class APIView(View): return self.response def dispatch(self, request, *args, **kwargs): + """ + Dispatch checks if the view is async or not and uses the respective + async or sync dispatch method. + """ if hasattr(self, 'view_is_async') and self.view_is_async: return self.async_dispatch(request, *args, **kwargs) else: diff --git a/tests/test_views.py b/tests/test_views.py index d46b71fbb..8aeb7e8a7 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -45,6 +45,18 @@ def basic_view(request): return Response({'method': 'PATCH', 'data': request.data}) +@api_view(['GET', 'POST', 'PUT', 'PATCH']) +async def basic_async_view(request): + if request.method == 'GET': + return Response({'method': 'GET'}) + elif request.method == 'POST': + return Response({'method': 'POST', 'data': request.data}) + elif request.method == 'PUT': + return Response({'method': 'PUT', 'data': request.data}) + elif request.method == 'PATCH': + return Response({'method': 'PATCH', 'data': request.data}) + + class ErrorView(APIView): def get(self, request, *args, **kwargs): raise Exception @@ -173,6 +185,40 @@ class ClassBasedAsyncViewIntegrationTests(TestCase): assert sanitise_json_error(response.data) == expected +@pytest.mark.skipif( + django.VERSION < (4, 1), + reason="Async view support requires Django 4.1 or higher", +) +class FunctionBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = basic_async_view + + def test_get_succeeds(self): + request = factory.get('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + def test_post_succeeds(self): + request = factory.post('/', {'test': 'foo'}) + response = async_to_sync(self.view)(request) + expected = { + 'method': 'POST', + 'data': {'test': ['foo']} + } + assert response.status_code == status.HTTP_200_OK + assert response.data == expected + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = async_to_sync(self.view)(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + class TestCustomExceptionHandler(TestCase): def setUp(self): self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER