diff --git a/rest_framework/views.py b/rest_framework/views.py index f30da98d3..9a4a75747 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -484,7 +484,7 @@ class APIView(View): # Note: Views are made CSRF exempt from within `as_view` as to prevent # accidental removal of this exemption in cases where `dispatch` needs to # be overridden. - def dispatch(self, request, *args, **kwargs): + def sync_dispatch(self, request, *args, **kwargs): """ `.dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. @@ -513,18 +513,7 @@ class APIView(View): self.response = self.finalize_response(request, response, *args, **kwargs) return self.response - def options(self, request, *args, **kwargs): - """ - Handler method for HTTP 'OPTIONS' request. - """ - if self.metadata_class is None: - return self.http_method_not_allowed(request, *args, **kwargs) - data = self.metadata_class().determine_metadata(request, self) - return Response(data, status=status.HTTP_200_OK) - - -class AsyncAPIView(APIView): - async def dispatch(self, request, *args, **kwargs): + async def async_dispatch(self, request, *args, **kwargs): """ `.dispatch()` is pretty much the same as Django's regular dispatch, but with extra hooks for startup, finalize, and exception handling. @@ -555,3 +544,18 @@ class AsyncAPIView(APIView): self.response = self.finalize_response(request, response, *args, **kwargs) return self.response + + def dispatch(self, request, *args, **kwargs): + if hasattr(self, 'view_is_async') and self.view_is_async: + return self.async_dispatch(request, *args, **kwargs) + else: + return self.sync_dispatch(request, *args, **kwargs) + + def options(self, request, *args, **kwargs): + """ + Handler method for HTTP 'OPTIONS' request. + """ + if self.metadata_class is None: + return self.http_method_not_allowed(request, *args, **kwargs) + data = self.metadata_class().determine_metadata(request, self) + return Response(data, status=status.HTTP_200_OK) diff --git a/tests/test_views.py b/tests/test_views.py index 64fceb0d6..d46b71fbb 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -10,7 +10,7 @@ from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import APISettings, api_settings from rest_framework.test import APIRequestFactory -from rest_framework.views import APIView, AsyncAPIView +from rest_framework.views import APIView factory = APIRequestFactory() @@ -25,7 +25,7 @@ class BasicView(APIView): return Response({'method': 'POST', 'data': request.data}) -class BasicAsyncView(AsyncAPIView): +class BasicAsyncView(APIView): async def get(self, request, *args, **kwargs): return Response({'method': 'GET'})