Fix for options and authentication

This commit is contained in:
enrico 2022-09-18 19:50:02 -04:00
parent f16767f2e5
commit c6cf2d7e83
3 changed files with 74 additions and 6 deletions

View File

@ -43,9 +43,10 @@ except ImportError:
# async_to_sync is required for async view support # async_to_sync is required for async view support
if django.VERSION >= (4, 1): if django.VERSION >= (4, 1):
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync, sync_to_async
else: else:
async_to_sync = None async_to_sync = None
sync_to_async = None
# coreschema is optional # coreschema is optional

View File

@ -11,6 +11,7 @@ from django.utils.encoding import smart_str
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View from django.views.generic import View
from rest_framework.compat import sync_to_async
from rest_framework import exceptions, status from rest_framework import exceptions, status
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
@ -524,7 +525,7 @@ class APIView(View):
self.headers = self.default_response_headers # deprecate? self.headers = self.default_response_headers # deprecate?
try: try:
self.initial(request, *args, **kwargs) sync_to_async(self.initial)(request, *args, **kwargs)
# Get the appropriate handler method # Get the appropriate handler method
if request.method.lower() in self.http_method_names: if request.method.lower() in self.http_method_names:
@ -555,7 +556,16 @@ class APIView(View):
""" """
Handler method for HTTP 'OPTIONS' request. Handler method for HTTP 'OPTIONS' request.
""" """
def func():
if self.metadata_class is None: if self.metadata_class is None:
return self.http_method_not_allowed(request, *args, **kwargs) return self.http_method_not_allowed(request, *args, **kwargs)
data = self.metadata_class().determine_metadata(request, self) data = self.metadata_class().determine_metadata(request, self)
return Response(data, status=status.HTTP_200_OK) return Response(data, status=status.HTTP_200_OK)
if hasattr(self, 'view_is_async') and self.view_is_async:
async def handler():
return func()
else:
def handler():
return func()
return handler()

View File

@ -3,6 +3,7 @@ import copy
import django import django
import pytest import pytest
from django.test import TestCase from django.test import TestCase
from django.contrib.auth.models import User
from rest_framework import status from rest_framework import status
from rest_framework.compat import async_to_sync from rest_framework.compat import async_to_sync
@ -101,6 +102,15 @@ class ClassBasedViewIntegrationTests(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'} assert response.data == {'method': 'GET'}
def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', 'user@example.com', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_post_succeeds(self): def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'}) request = factory.post('/', {'test': 'foo'})
response = self.view(request) response = self.view(request)
@ -111,6 +121,11 @@ class ClassBasedViewIntegrationTests(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == expected assert response.data == expected
def test_options_succeeds(self):
request = factory.options('/')
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
def test_400_parse_error(self): def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = self.view(request)
@ -131,6 +146,15 @@ class FunctionBasedViewIntegrationTests(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'} assert response.data == {'method': 'GET'}
def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', 'user@example.com', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_post_succeeds(self): def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'}) request = factory.post('/', {'test': 'foo'})
response = self.view(request) response = self.view(request)
@ -141,6 +165,11 @@ class FunctionBasedViewIntegrationTests(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == expected assert response.data == expected
def test_options_succeeds(self):
request = factory.options('/')
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
def test_400_parse_error(self): def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request) response = self.view(request)
@ -165,6 +194,15 @@ class ClassBasedAsyncViewIntegrationTests(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'} assert response.data == {'method': 'GET'}
def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', 'user@example.com', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_post_succeeds(self): def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'}) request = factory.post('/', {'test': 'foo'})
response = async_to_sync(self.view)(request) response = async_to_sync(self.view)(request)
@ -175,6 +213,11 @@ class ClassBasedAsyncViewIntegrationTests(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == expected assert response.data == expected
def test_options_succeeds(self):
request = factory.options('/')
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
def test_400_parse_error(self): def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = async_to_sync(self.view)(request) response = async_to_sync(self.view)(request)
@ -199,6 +242,15 @@ class FunctionBasedAsyncViewIntegrationTests(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'} assert response.data == {'method': 'GET'}
def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', 'user@example.com', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_post_succeeds(self): def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'}) request = factory.post('/', {'test': 'foo'})
response = async_to_sync(self.view)(request) response = async_to_sync(self.view)(request)
@ -209,6 +261,11 @@ class FunctionBasedAsyncViewIntegrationTests(TestCase):
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert response.data == expected assert response.data == expected
def test_options_succeeds(self):
request = factory.options('/')
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
def test_400_parse_error(self): def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json') request = factory.post('/', 'f00bar', content_type='application/json')
response = async_to_sync(self.view)(request) response = async_to_sync(self.view)(request)