From 79ce07ba9b0d46cb7c0c592ef47e7f834f651df4 Mon Sep 17 00:00:00 2001 From: enrico Date: Fri, 19 Aug 2022 17:03:22 +0800 Subject: [PATCH] Tentative async implementation --- rest_framework/views.py | 7 +++- tests/test_views.py | 79 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/rest_framework/views.py b/rest_framework/views.py index 5b0622069..a90b4d3c0 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,6 +1,8 @@ """ Provides an APIView class that is the base of all views in REST framework. """ +import asyncio +from asgiref.sync import async_to_sync from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import connections, models @@ -503,7 +505,10 @@ class APIView(View): else: handler = self.http_method_not_allowed - response = handler(request, *args, **kwargs) + if asyncio.iscoroutinefunction(handler): + response = async_to_sync(handler)(request, *args, **kwargs) + else: + response = handler(request, *args, **kwargs) except Exception as exc: response = self.handle_exception(exc) diff --git a/tests/test_views.py b/tests/test_views.py index 2648c9fb3..da6271d2c 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -22,16 +22,24 @@ class BasicView(APIView): return Response({'method': 'POST', 'data': request.data}) +class BasicAsyncView(APIView): + async def get(self, request, *args, **kwargs): + return Response({'method': 'GET'}) + + async def post(self, request, *args, **kwargs): + return Response({'method': 'POST', 'data': request.data}) + + @api_view(['GET', 'POST', 'PUT', 'PATCH']) def basic_view(request): if request.method == 'GET': - return {'method': 'GET'} + return Response({'method': 'GET'}) elif request.method == 'POST': - return {'method': 'POST', 'data': request.data} + return Response({'method': 'POST', 'data': request.data}) elif request.method == 'PUT': - return {'method': 'PUT', 'data': request.data} + return Response({'method': 'PUT', 'data': request.data}) elif request.method == 'PATCH': - return {'method': 'PATCH', 'data': request.data} + return Response({'method': 'PATCH', 'data': request.data}) class ErrorView(APIView): @@ -72,6 +80,23 @@ class ClassBasedViewIntegrationTests(TestCase): def setUp(self): self.view = BasicView.as_view() + def test_get_succeeds(self): + request = factory.get('/', content_type='application/json') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + # def test_post_succeeds(self): + # request = factory.post('/', {"test": "foo"}, content_type='application/json') + # response = self.view(request) + # import pdb; pdb.set_trace() + # expected = { + # 'method': 'POST', + # 'data': {'test': 'foo'} + # } + # assert response.status_code == status.HTTP_200_OK + # assert response.data == expected + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -86,6 +111,52 @@ class FunctionBasedViewIntegrationTests(TestCase): def setUp(self): self.view = basic_view + def test_get_succeeds(self): + request = factory.get('/', content_type='application/json') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + + # def test_post_succeeds(self): + # request = factory.post('/', {'test': 'foo'}, content_type='application/json') + # response = self.view(request) + # expected = { + # 'method': 'POST', + # 'data': {'test': 'foo'} + # } + # assert response.status_code == status.HTTP_200_OK + # assert response.data == expected + + def test_400_parse_error(self): + request = factory.post('/', 'f00bar', content_type='application/json') + response = self.view(request) + expected = { + 'detail': JSON_ERROR + } + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert sanitise_json_error(response.data) == expected + + +class ClassBasedAsyncViewIntegrationTests(TestCase): + def setUp(self): + self.view = BasicAsyncView.as_view() + + def test_get_succeeds(self): + request = factory.get('/', content_type='application/json') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + +# def test_post_succeeds(self): +# request = factory.post('/', {'test': 'foo'}, content_type='application/json') +# response = self.view(request) +# expected = { +# 'method': 'POST', +# 'data': {'test': 'foo'} +# } +# assert response.status_code == status.HTTP_200_OK +# assert response.data == expected + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request)