From c6cf2d7e83d6cb7a432e20f9ca40b169a7fea761 Mon Sep 17 00:00:00 2001 From: enrico Date: Sun, 18 Sep 2022 19:50:02 -0400 Subject: [PATCH] Fix for options and authentication --- rest_framework/compat.py | 3 ++- rest_framework/views.py | 20 ++++++++++---- tests/test_views.py | 57 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 74 insertions(+), 6 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 5924d4651..9cb5b76f3 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -43,9 +43,10 @@ except ImportError: # async_to_sync is required for async view support if django.VERSION >= (4, 1): - from asgiref.sync import async_to_sync + from asgiref.sync import async_to_sync, sync_to_async else: async_to_sync = None + sync_to_async = None # coreschema is optional diff --git a/rest_framework/views.py b/rest_framework/views.py index 9bb9a61ac..180fdfd29 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -11,6 +11,7 @@ from django.utils.encoding import smart_str from django.views.decorators.csrf import csrf_exempt from django.views.generic import View +from rest_framework.compat import sync_to_async from rest_framework import exceptions, status from rest_framework.request import Request from rest_framework.response import Response @@ -524,7 +525,7 @@ class APIView(View): self.headers = self.default_response_headers # deprecate? try: - self.initial(request, *args, **kwargs) + sync_to_async(self.initial)(request, *args, **kwargs) # Get the appropriate handler method if request.method.lower() in self.http_method_names: @@ -555,7 +556,16 @@ class APIView(View): """ Handler method for HTTP 'OPTIONS' request. """ - if self.metadata_class is None: - return self.http_method_not_allowed(request, *args, **kwargs) - data = self.metadata_class().determine_metadata(request, self) - return Response(data, status=status.HTTP_200_OK) + def func(): + if self.metadata_class is None: + return self.http_method_not_allowed(request, *args, **kwargs) + data = self.metadata_class().determine_metadata(request, self) + return Response(data, status=status.HTTP_200_OK) + + if hasattr(self, 'view_is_async') and self.view_is_async: + async def handler(): + return func() + else: + def handler(): + return func() + return handler() diff --git a/tests/test_views.py b/tests/test_views.py index 8aeb7e8a7..49fdbe476 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -3,6 +3,7 @@ import copy import django import pytest from django.test import TestCase +from django.contrib.auth.models import User from rest_framework import status from rest_framework.compat import async_to_sync @@ -101,6 +102,15 @@ class ClassBasedViewIntegrationTests(TestCase): assert response.status_code == status.HTTP_200_OK assert response.data == {'method': 'GET'} + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + def test_post_succeeds(self): request = factory.post('/', {'test': 'foo'}) response = self.view(request) @@ -111,6 +121,11 @@ class ClassBasedViewIntegrationTests(TestCase): assert response.status_code == status.HTTP_200_OK assert response.data == expected + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -131,6 +146,15 @@ class FunctionBasedViewIntegrationTests(TestCase): assert response.status_code == status.HTTP_200_OK assert response.data == {'method': 'GET'} + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + def test_post_succeeds(self): request = factory.post('/', {'test': 'foo'}) response = self.view(request) @@ -141,6 +165,11 @@ class FunctionBasedViewIntegrationTests(TestCase): assert response.status_code == status.HTTP_200_OK assert response.data == expected + def test_options_succeeds(self): + request = factory.options('/') + response = self.view(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = self.view(request) @@ -165,6 +194,15 @@ class ClassBasedAsyncViewIntegrationTests(TestCase): assert response.status_code == status.HTTP_200_OK assert response.data == {'method': 'GET'} + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + def test_post_succeeds(self): request = factory.post('/', {'test': 'foo'}) response = async_to_sync(self.view)(request) @@ -175,6 +213,11 @@ class ClassBasedAsyncViewIntegrationTests(TestCase): assert response.status_code == status.HTTP_200_OK assert response.data == expected + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = async_to_sync(self.view)(request) @@ -199,6 +242,15 @@ class FunctionBasedAsyncViewIntegrationTests(TestCase): assert response.status_code == status.HTTP_200_OK assert response.data == {'method': 'GET'} + def test_logged_in_get_succeeds(self): + user = User.objects.create_user('user', 'user@example.com', 'password') + request = factory.get('/') + del user.is_active + request.user = user + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + assert response.data == {'method': 'GET'} + def test_post_succeeds(self): request = factory.post('/', {'test': 'foo'}) response = async_to_sync(self.view)(request) @@ -209,6 +261,11 @@ class FunctionBasedAsyncViewIntegrationTests(TestCase): assert response.status_code == status.HTTP_200_OK assert response.data == expected + def test_options_succeeds(self): + request = factory.options('/') + response = async_to_sync(self.view)(request) + assert response.status_code == status.HTTP_200_OK + def test_400_parse_error(self): request = factory.post('/', 'f00bar', content_type='application/json') response = async_to_sync(self.view)(request)