Async view implementation

This commit is contained in:
James Hilliard 2023-05-10 08:51:42 -06:00
parent 4f7e9ed3bb
commit 979bb24bec
9 changed files with 932 additions and 68 deletions

View File

@ -12,6 +12,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
fail-fast: false
matrix:
python-version:
- '3.6'

View File

@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation.
def view(request):
return Response({"message": "Will not appear in schema!"})
# Async Views
When using Django 4.1 and above, REST framework allows you to work with async class and function based views.
For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async.
For example:
class AsyncView(APIView):
async def get(self, request):
return Response({"message": "This is an async class based view."})
@api_view(['GET'])
async def async_view(request):
return Response({"message": "This is an async function based view."})
[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html

View File

@ -41,6 +41,17 @@ except ImportError:
uritemplate = None
# async_to_sync is required for async view support
if django.VERSION >= (4, 1):
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
else:
async_to_sync = None
sync_to_async = None
def iscoroutinefunction(func):
return False
# coreschema is optional
try:
import coreschema

View File

@ -10,6 +10,7 @@ import types
from django.forms.utils import pretty_name
from rest_framework.compat import iscoroutinefunction
from rest_framework.views import APIView
@ -46,8 +47,12 @@ def api_view(http_method_names=None):
allowed_methods = set(http_method_names) | {'options'}
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
def handler(self, *args, **kwargs):
return func(*args, **kwargs)
if iscoroutinefunction(func):
async def handler(self, *args, **kwargs):
return await func(*args, **kwargs)
else:
def handler(self, *args, **kwargs):
return func(*args, **kwargs)
for method in http_method_names:
setattr(WrappedAPIView, method.lower(), handler)

View File

@ -11,6 +11,10 @@ from django.test import override_settings, testcases
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test.client import RequestFactory as DjangoRequestFactory
if django.VERSION >= (4, 1):
from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory
from django.utils.encoding import force_bytes
from django.utils.http import urlencode
@ -136,7 +140,7 @@ else:
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
class APIRequestFactory(DjangoRequestFactory):
class APIRequestFactoryMixin:
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
@ -240,6 +244,15 @@ class APIRequestFactory(DjangoRequestFactory):
return request
class APIRequestFactory(APIRequestFactoryMixin, DjangoRequestFactory):
pass
if django.VERSION >= (4, 1):
class APIAsyncRequestFactory(APIRequestFactoryMixin, DjangoAsyncRequestFactory):
pass
class ForceAuthClientHandler(ClientHandler):
"""
A patched version of ClientHandler that can enforce authentication

View File

@ -6,6 +6,9 @@ import time
from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured
from rest_framework.compat import (
async_to_sync, iscoroutinefunction, sync_to_async
)
from rest_framework.settings import api_settings
@ -64,6 +67,8 @@ class SimpleRateThrottle(BaseThrottle):
cache_format = 'throttle_%(scope)s_%(ident)s'
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
sync_capable = True
async_capable = True
def __init__(self):
if not getattr(self, 'rate', None):
@ -113,23 +118,52 @@ class SimpleRateThrottle(BaseThrottle):
On success calls `throttle_success`.
On failure calls `throttle_failure`.
"""
if self.rate is None:
return True
if getattr(view, 'view_is_async', False):
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
async def func():
if self.rate is None:
return True
self.history = self.cache.get(self.key, [])
self.now = self.timer()
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
self.history = self.cache.get(self.key, [])
if iscoroutinefunction(self.timer):
self.now = await self.timer()
else:
self.now = await sync_to_async(self.timer)()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
return func()
else:
if self.rate is None:
return True
self.key = self.get_cache_key(request, view)
if self.key is None:
return True
self.history = self.cache.get(self.key, [])
if iscoroutinefunction(self.timer):
self.now = async_to_sync(self.timer)()
else:
self.now = self.timer()
# Drop any requests from the history which have now passed the
# throttle duration
while self.history and self.history[-1] <= self.now - self.duration:
self.history.pop()
if len(self.history) >= self.num_requests:
return self.throttle_failure()
return self.throttle_success()
def throttle_success(self):
"""
@ -210,6 +244,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
user id of the request, and the scope of the view being accessed.
"""
scope_attr = 'throttle_scope'
sync_capable = True
async_capable = True
def __init__(self):
# Override the usual SimpleRateThrottle, because we can't determine
@ -220,17 +256,34 @@ class ScopedRateThrottle(SimpleRateThrottle):
# We can only determine the scope once we're called by the view.
self.scope = getattr(view, self.scope_attr, None)
# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True
if getattr(view, 'view_is_async', False):
# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
async def func(allow_request):
# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True
# We can now proceed as normal.
return super().allow_request(request, view)
# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
# We can now proceed as normal.
return await allow_request(request, view)
return func(super().allow_request)
else:
# If a view does not have a `throttle_scope` always allow the request
if not self.scope:
return True
# Determine the allowed request rate as we normally would during
# the `__init__` call.
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
# We can now proceed as normal.
return super().allow_request(request, view)
def get_cache_key(self, request, view):
"""

View File

@ -1,6 +1,8 @@
"""
Provides an APIView class that is the base of all views in REST framework.
"""
import asyncio
from django.conf import settings
from django.core.exceptions import PermissionDenied
from django.db import connections, models
@ -12,6 +14,9 @@ from django.views.decorators.csrf import csrf_exempt
from django.views.generic import View
from rest_framework import exceptions, status
from rest_framework.compat import (
async_to_sync, iscoroutinefunction, sync_to_async
)
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.schemas import DefaultSchema
@ -328,13 +333,52 @@ class APIView(View):
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
"""
async_permissions, sync_permissions = [], []
for permission in self.get_permissions():
if not permission.has_permission(request, self):
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
if iscoroutinefunction(permission.has_permission):
async_permissions.append(permission)
else:
sync_permissions.append(permission)
async def check_async():
results = await asyncio.gather(
*(permission.has_permission(request, self) for permission in
async_permissions), return_exceptions=True
)
for idx in range(len(async_permissions)):
if isinstance(results[idx], Exception):
raise results[idx]
elif not results[idx]:
self.permission_denied(
request,
message=getattr(async_permissions[idx], "message", None),
code=getattr(async_permissions[idx], "code", None),
)
def check_sync():
for permission in sync_permissions:
if not permission.has_permission(request, self):
self.permission_denied(
request,
message=getattr(permission, 'message', None),
code=getattr(permission, 'code', None)
)
if getattr(self, 'view_is_async', False):
async def func():
if async_permissions:
await check_async()
if sync_permissions:
await sync_to_async(check_sync)()
return func()
else:
if sync_permissions:
check_sync()
if async_permissions:
async_to_sync(check_async)
def check_object_permissions(self, request, obj):
"""
@ -354,21 +398,79 @@ class APIView(View):
Check if request should be throttled.
Raises an appropriate exception if the request is throttled.
"""
throttle_durations = []
async_throttle_durations, sync_throttle_durations = [], []
view_is_async = getattr(self, 'view_is_async', False)
for throttle in self.get_throttles():
if not throttle.allow_request(request, self):
throttle_durations.append(throttle.wait())
throttle_can_sync = getattr(throttle, "sync_capable", True)
throttle_can_async = getattr(throttle, "async_capable", False)
if not throttle_can_sync and not throttle_can_async:
raise RuntimeError(
"Throttle %s must have at least one of "
"sync_capable/async_capable set to True." % throttle.__class__.__name__
)
elif not view_is_async and throttle_can_sync:
throttle_is_async = False
elif iscoroutinefunction(throttle.allow_request):
throttle_is_async = True
else:
throttle_is_async = throttle_can_async
if throttle_is_async:
async_throttle_durations.append(throttle)
else:
sync_throttle_durations.append(throttle)
if throttle_durations:
# Filter out `None` values which may happen in case of config / rate
# changes, see #1438
durations = [
duration for duration in throttle_durations
if duration is not None
]
async def async_throttles():
for throttle in async_throttle_durations:
if not await throttle.allow_request(request, self):
yield throttle.wait()
duration = max(durations, default=None)
self.throttled(request, duration)
def sync_throttles():
for throttle in sync_throttle_durations:
if not throttle.allow_request(request, self):
yield throttle.wait()
if view_is_async:
async def func():
throttle_durations = []
if async_throttle_durations:
throttle_durations.extend([duration async for duration in async_throttles()])
if sync_throttle_durations:
throttle_durations.extend(duration for duration in sync_throttles())
if throttle_durations:
# Filter out `None` values which may happen in case of config / rate
# changes, see #1438
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
self.throttled(request, duration)
return func()
else:
throttle_durations = []
if sync_throttle_durations:
throttle_durations.extend(sync_throttles())
if async_throttle_durations:
throttle_durations.extend(async_to_sync(async_throttles)())
if throttle_durations:
# Filter out `None` values which may happen in case of config / rate
# changes, see #1438
durations = [
duration for duration in throttle_durations
if duration is not None
]
duration = max(durations, default=None)
self.throttled(request, duration)
def determine_version(self, request, *args, **kwargs):
"""
@ -410,10 +512,20 @@ class APIView(View):
version, scheme = self.determine_version(request, *args, **kwargs)
request.version, request.versioning_scheme = version, scheme
# Ensure that the incoming request is permitted
self.perform_authentication(request)
self.check_permissions(request)
self.check_throttles(request)
if getattr(self, 'view_is_async', False):
async def func():
# Ensure that the incoming request is permitted
await sync_to_async(self.perform_authentication)(request)
await self.check_permissions(request)
await self.check_throttles(request)
return func()
else:
# Ensure that the incoming request is permitted
self.perform_authentication(request)
self.check_permissions(request)
self.check_throttles(request)
def finalize_response(self, request, response, *args, **kwargs):
"""
@ -469,7 +581,15 @@ class APIView(View):
self.raise_uncaught_exception(exc)
response.exception = True
return response
if getattr(self, 'view_is_async', False):
async def func():
return response
return func()
else:
return response
def raise_uncaught_exception(self, exc):
if settings.DEBUG:
@ -493,23 +613,49 @@ class APIView(View):
self.request = request
self.headers = self.default_response_headers # deprecate?
try:
self.initial(request, *args, **kwargs)
if getattr(self, 'view_is_async', False):
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
async def func():
response = handler(request, *args, **kwargs)
try:
await self.initial(request, *args, **kwargs)
except Exception as exc:
response = self.handle_exception(exc)
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response
response = await handler(request, *args, **kwargs)
except Exception as exc:
response = await self.handle_exception(exc)
return self.finalize_response(request, response, *args, **kwargs)
self.response = func()
return self.response
else:
try:
self.initial(request, *args, **kwargs)
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
else:
handler = self.http_method_not_allowed
response = handler(request, *args, **kwargs)
except Exception as exc:
response = self.handle_exception(exc)
self.response = self.finalize_response(request, response, *args, **kwargs)
return self.response
def options(self, request, *args, **kwargs):
"""
@ -518,4 +664,12 @@ class APIView(View):
if self.metadata_class is None:
return self.http_method_not_allowed(request, *args, **kwargs)
data = self.metadata_class().determine_metadata(request, self)
return Response(data, status=status.HTTP_200_OK)
if getattr(self, 'view_is_async', False):
async def func():
return Response(data, status=status.HTTP_200_OK)
return func()
else:
return Response(data, status=status.HTTP_200_OK)

View File

@ -1,7 +1,7 @@
"""
Tests for the throttling implementations in the permissions module.
"""
import django
import pytest
from django.contrib.auth.models import User
from django.core.cache import cache
@ -9,10 +9,15 @@ from django.core.exceptions import ImproperlyConfigured
from django.http import HttpRequest
from django.test import TestCase
from rest_framework.compat import async_to_sync
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, force_authenticate
if django.VERSION >= (4, 1):
from rest_framework.test import APIAsyncRequestFactory
from rest_framework.throttling import (
AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle,
UserRateThrottle
@ -43,6 +48,14 @@ class NonTimeThrottle(BaseThrottle):
return False
class NonTimeAsyncThrottle(BaseThrottle):
def allow_request(self, request, view):
if not hasattr(self.__class__, 'called'):
self.__class__.called = True
return True
return False
class MockView_DoubleThrottling(APIView):
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
@ -50,6 +63,13 @@ class MockView_DoubleThrottling(APIView):
return Response('foo')
class MockAsyncView_DoubleThrottling(APIView):
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
async def get(self, request):
return Response('foo')
class MockView(APIView):
throttle_classes = (User3SecRateThrottle,)
@ -57,6 +77,13 @@ class MockView(APIView):
return Response('foo')
class MockAsyncView(APIView):
throttle_classes = (User3SecRateThrottle,)
async def get(self, request):
return Response('foo')
class MockView_MinuteThrottling(APIView):
throttle_classes = (User3MinRateThrottle,)
@ -64,6 +91,13 @@ class MockView_MinuteThrottling(APIView):
return Response('foo')
class MockAsyncView_MinuteThrottling(APIView):
throttle_classes = (User3MinRateThrottle,)
async def get(self, request):
return Response('foo')
class MockView_NonTimeThrottling(APIView):
throttle_classes = (NonTimeThrottle,)
@ -71,6 +105,13 @@ class MockView_NonTimeThrottling(APIView):
return Response('foo')
class MockAsyncView_NonTimeThrottling(APIView):
throttle_classes = (NonTimeAsyncThrottle,)
async def get(self, request):
return Response('foo')
class ThrottlingTests(TestCase):
def setUp(self):
"""
@ -252,12 +293,198 @@ class ThrottlingTests(TestCase):
self.assertFalse('Retry-After' in response)
@pytest.mark.skipif(
django.VERSION < (4, 1),
reason="Async view support requires Django 4.1 or higher",
)
class AsyncThrottlingTests(TestCase):
def setUp(self):
"""
Reset the cache so that no throttles will be active
"""
cache.clear()
self.factory = APIAsyncRequestFactory()
def test_requests_are_throttled(self):
"""
Ensure request rate is limited
"""
request = self.factory.get('/')
for dummy in range(4):
response = async_to_sync(MockAsyncView.as_view())(request)
assert response.status_code == 429
def set_throttle_timer(self, view, value):
"""
Explicitly set the timer, overriding time.time()
"""
for cls in view.throttle_classes:
cls.timer = lambda self: value
def test_request_throttling_expires(self):
"""
Ensure request rate is limited for a limited duration only
"""
self.set_throttle_timer(MockAsyncView, 0)
request = self.factory.get('/')
for dummy in range(4):
response = async_to_sync(MockAsyncView.as_view())(request)
assert response.status_code == 429
# Advance the timer by one second
self.set_throttle_timer(MockAsyncView, 1)
response = async_to_sync(MockAsyncView.as_view())(request)
assert response.status_code == 200
async def ensure_is_throttled(self, view, expect):
request = self.factory.get('/')
request.user = await User.objects.acreate(username='a')
for dummy in range(3):
await view.as_view()(request)
request.user = await User.objects.acreate(username='b')
response = await view.as_view()(request)
assert response.status_code == expect
def test_request_throttling_is_per_user(self):
"""
Ensure request rate is only limited per user, not globally for
PerUserThrottles
"""
async_to_sync(self.ensure_is_throttled)(MockAsyncView, 200)
def test_request_throttling_multiple_throttles(self):
"""
Ensure all throttle classes see each request even when the request is
already being throttled
"""
self.set_throttle_timer(MockAsyncView_DoubleThrottling, 0)
request = self.factory.get('/')
for dummy in range(4):
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
assert response.status_code == 429
assert int(response['retry-after']) == 1
# At this point our client made 4 requests (one was throttled) in a
# second. If we advance the timer by one additional second, the client
# should be allowed to make 2 more before being throttled by the 2nd
# throttle class, which has a limit of 6 per minute.
self.set_throttle_timer(MockAsyncView_DoubleThrottling, 1)
for dummy in range(2):
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
assert response.status_code == 200
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
assert response.status_code == 429
assert int(response['retry-after']) == 59
# Just to make sure check again after two more seconds.
self.set_throttle_timer(MockAsyncView_DoubleThrottling, 2)
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
assert response.status_code == 429
assert int(response['retry-after']) == 58
def test_throttle_rate_change_negative(self):
self.set_throttle_timer(MockAsyncView_DoubleThrottling, 0)
request = self.factory.get('/')
for dummy in range(24):
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
assert response.status_code == 429
assert int(response['retry-after']) == 60
previous_rate = User3SecRateThrottle.rate
try:
User3SecRateThrottle.rate = '1/sec'
for dummy in range(24):
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
assert response.status_code == 429
assert int(response['retry-after']) == 60
finally:
# reset
User3SecRateThrottle.rate = previous_rate
async def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
"""
Ensure the response returns an Retry-After field with status and next attributes
set properly.
"""
request = self.factory.get('/')
for timer, expect in expected_headers:
self.set_throttle_timer(view, timer)
response = await view.as_view()(request)
if expect is not None:
assert response['Retry-After'] == expect
else:
assert not'Retry-After' in response
def test_seconds_fields(self):
"""
Ensure for second based throttles.
"""
async_to_sync(self.ensure_response_header_contains_proper_throttle_field)(
MockAsyncView, (
(0, None),
(0, None),
(0, None),
(0, '1')
)
)
def test_minutes_fields(self):
"""
Ensure for minute based throttles.
"""
async_to_sync(self.ensure_response_header_contains_proper_throttle_field)(
MockAsyncView_MinuteThrottling, (
(0, None),
(0, None),
(0, None),
(0, '60')
)
)
def test_next_rate_remains_constant_if_followed(self):
"""
If a client follows the recommended next request rate,
the throttling rate should stay constant.
"""
async_to_sync(self.ensure_response_header_contains_proper_throttle_field)(
MockAsyncView_MinuteThrottling, (
(0, None),
(20, None),
(40, None),
(60, None),
(80, None)
)
)
def test_non_time_throttle(self):
"""
Ensure for second based throttles.
"""
request = self.factory.get('/')
self.assertFalse(hasattr(MockAsyncView_NonTimeThrottling.throttle_classes[0], 'called'))
response = async_to_sync(MockAsyncView_NonTimeThrottling.as_view())(request)
self.assertFalse('Retry-After' in response)
self.assertTrue(MockAsyncView_NonTimeThrottling.throttle_classes[0].called)
response = async_to_sync(MockAsyncView_NonTimeThrottling.as_view())(request)
self.assertFalse('Retry-After' in response)
class ScopedRateThrottleTests(TestCase):
"""
Tests for ScopedRateThrottle.
"""
def setUp(self):
cache.clear()
self.throttle = ScopedRateThrottle()
class XYScopedRateThrottle(ScopedRateThrottle):
@ -372,6 +599,131 @@ class ScopedRateThrottleTests(TestCase):
assert cache_key == 'throttle_user_%s' % user.pk
@pytest.mark.skipif(
django.VERSION < (4, 1),
reason="Async view support requires Django 4.1 or higher",
)
class AsyncScopedRateThrottleTests(TestCase):
"""
Tests for ScopedRateThrottle.
"""
def setUp(self):
cache.clear()
self.throttle = ScopedRateThrottle()
class XYScopedRateThrottle(ScopedRateThrottle):
TIMER_SECONDS = 0
THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
async def timer(self):
return self.TIMER_SECONDS
class XView(APIView):
throttle_classes = (XYScopedRateThrottle,)
throttle_scope = 'x'
async def get(self, request):
return Response('x')
class YView(APIView):
throttle_classes = (XYScopedRateThrottle,)
throttle_scope = 'y'
async def get(self, request):
return Response('y')
class UnscopedView(APIView):
throttle_classes = (XYScopedRateThrottle,)
async def get(self, request):
return Response('y')
self.throttle_class = XYScopedRateThrottle
self.factory = APIAsyncRequestFactory()
self.x_view = XView.as_view()
self.y_view = YView.as_view()
self.unscoped_view = UnscopedView.as_view()
def increment_timer(self, seconds=1):
self.throttle_class.TIMER_SECONDS += seconds
def test_scoped_rate_throttle(self):
request = self.factory.get('/')
# Should be able to hit x view 3 times per minute.
response = async_to_sync(self.x_view)(request)
assert response.status_code == 200
self.increment_timer()
response = async_to_sync(self.x_view)(request)
assert response.status_code == 200
self.increment_timer()
response = async_to_sync(self.x_view)(request)
assert response.status_code == 200
self.increment_timer()
response = async_to_sync(self.x_view)(request)
assert response.status_code == 429
# Should be able to hit y view 1 time per minute.
self.increment_timer()
response = async_to_sync(self.y_view)(request)
assert response.status_code == 200
self.increment_timer()
response = async_to_sync(self.y_view)(request)
assert response.status_code == 429
# Ensure throttles properly reset by advancing the rest of the minute
self.increment_timer(55)
# Should still be able to hit x view 3 times per minute.
response = async_to_sync(self.x_view)(request)
assert response.status_code == 200
self.increment_timer()
response = async_to_sync(self.x_view)(request)
assert response.status_code == 200
self.increment_timer()
response = async_to_sync(self.x_view)(request)
assert response.status_code == 200
self.increment_timer()
response = async_to_sync(self.x_view)(request)
assert response.status_code == 429
# Should still be able to hit y view 1 time per minute.
self.increment_timer()
response = async_to_sync(self.y_view)(request)
assert response.status_code == 200
self.increment_timer()
response = async_to_sync(self.y_view)(request)
assert response.status_code == 429
def test_unscoped_view_not_throttled(self):
request = self.factory.get('/')
for idx in range(10):
self.increment_timer()
response = async_to_sync(self.unscoped_view)(request)
assert response.status_code == 200
def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self):
class DummyView:
throttle_scope = 'user'
request = Request(HttpRequest())
user = User.objects.create(username='test')
force_authenticate(request, user)
request.user = user
self.throttle.allow_request(request, DummyView())
cache_key = self.throttle.get_cache_key(request, view=DummyView())
assert cache_key == 'throttle_user_%s' % user.pk
class XffTestingBase(TestCase):
def setUp(self):
@ -400,6 +752,34 @@ class XffTestingBase(TestCase):
setattr(api_settings, 'NUM_PROXIES', num_proxies)
class AsyncXffTestingBase(TestCase):
def setUp(self):
class Throttle(ScopedRateThrottle):
THROTTLE_RATES = {'test_limit': '1/day'}
TIMER_SECONDS = 0
async def timer(self):
return self.TIMER_SECONDS
class View(APIView):
throttle_classes = (Throttle,)
throttle_scope = 'test_limit'
async def get(self, request):
return Response('test_limit')
cache.clear()
self.throttle = Throttle()
self.view = View.as_view()
self.request = APIAsyncRequestFactory().get('/some_uri')
self.request.META['REMOTE_ADDR'] = '3.3.3.3'
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'
def config_proxy(self, num_proxies):
setattr(api_settings, 'NUM_PROXIES', num_proxies)
class IdWithXffBasicTests(XffTestingBase):
def test_accepts_request_under_limit(self):
self.config_proxy(0)
@ -411,6 +791,21 @@ class IdWithXffBasicTests(XffTestingBase):
assert self.view(self.request).status_code == 429
@pytest.mark.skipif(
django.VERSION < (4, 1),
reason="Async view support requires Django 4.1 or higher",
)
class AsyncIdWithXffBasicTests(AsyncXffTestingBase):
def test_accepts_request_under_limit(self):
self.config_proxy(0)
assert async_to_sync(self.view)(self.request).status_code == 200
def test_denies_request_over_limit(self):
self.config_proxy(0)
async_to_sync(self.view)(self.request)
assert async_to_sync(self.view)(self.request).status_code == 429
class XffSpoofingTests(XffTestingBase):
def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
self.config_proxy(1)
@ -425,6 +820,24 @@ class XffSpoofingTests(XffTestingBase):
assert self.view(self.request).status_code == 429
@pytest.mark.skipif(
django.VERSION < (4, 1),
reason="Async view support requires Django 4.1 or higher",
)
class AsyncXffSpoofingTests(AsyncXffTestingBase):
def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
self.config_proxy(1)
async_to_sync(self.view)(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
assert async_to_sync(self.view)(self.request).status_code == 429
def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
self.config_proxy(2)
async_to_sync(self.view)(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
assert async_to_sync(self.view)(self.request).status_code == 429
class XffUniqueMachinesTest(XffTestingBase):
def test_unique_clients_are_counted_independently_with_one_proxy(self):
self.config_proxy(1)
@ -439,6 +852,24 @@ class XffUniqueMachinesTest(XffTestingBase):
assert self.view(self.request).status_code == 200
@pytest.mark.skipif(
django.VERSION < (4, 1),
reason="Async view support requires Django 4.1 or higher",
)
class AsyncXffUniqueMachinesTest(AsyncXffTestingBase):
def test_unique_clients_are_counted_independently_with_one_proxy(self):
self.config_proxy(1)
async_to_sync(self.view)(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
assert async_to_sync(self.view)(self.request).status_code == 200
def test_unique_clients_are_counted_independently_with_two_proxies(self):
self.config_proxy(2)
async_to_sync(self.view)(self.request)
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
assert async_to_sync(self.view)(self.request).status_code == 200
class BaseThrottleTests(TestCase):
def test_allow_request_raises_not_implemented_error(self):

View File

@ -1,8 +1,12 @@
import copy
import django
import pytest
from django.contrib.auth.models import User
from django.test import TestCase
from rest_framework import status
from rest_framework.compat import async_to_sync
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.settings import APISettings, api_settings
@ -22,16 +26,36 @@ class BasicView(APIView):
return Response({'method': 'POST', 'data': request.data})
class BasicAsyncView(APIView):
async def get(self, request, *args, **kwargs):
return Response({'method': 'GET'})
async def post(self, request, *args, **kwargs):
return Response({'method': 'POST', 'data': request.data})
@api_view(['GET', 'POST', 'PUT', 'PATCH'])
def basic_view(request):
if request.method == 'GET':
return {'method': 'GET'}
return Response({'method': 'GET'})
elif request.method == 'POST':
return {'method': 'POST', 'data': request.data}
return Response({'method': 'POST', 'data': request.data})
elif request.method == 'PUT':
return {'method': 'PUT', 'data': request.data}
return Response({'method': 'PUT', 'data': request.data})
elif request.method == 'PATCH':
return {'method': 'PATCH', 'data': request.data}
return Response({'method': 'PATCH', 'data': request.data})
@api_view(['GET', 'POST', 'PUT', 'PATCH'])
async def basic_async_view(request):
if request.method == 'GET':
return Response({'method': 'GET'})
elif request.method == 'POST':
return Response({'method': 'POST', 'data': request.data})
elif request.method == 'PUT':
return Response({'method': 'PUT', 'data': request.data})
elif request.method == 'PATCH':
return Response({'method': 'PATCH', 'data': request.data})
class ErrorView(APIView):
@ -72,6 +96,36 @@ class ClassBasedViewIntegrationTests(TestCase):
def setUp(self):
self.view = BasicView.as_view()
def test_get_succeeds(self):
request = factory.get('/')
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', 'user@example.com', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'})
response = self.view(request)
expected = {
'method': 'POST',
'data': {'test': ['foo']}
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
def test_options_succeeds(self):
request = factory.options('/')
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
@ -82,10 +136,88 @@ class ClassBasedViewIntegrationTests(TestCase):
assert sanitise_json_error(response.data) == expected
@pytest.mark.skipif(
django.VERSION < (4, 1),
reason="Async view support requires Django 4.1 or higher",
)
class ClassBasedAsyncViewIntegrationTests(TestCase):
def setUp(self):
self.view = BasicAsyncView.as_view()
def test_get_succeeds(self):
request = factory.get('/')
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', 'user@example.com', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'})
response = async_to_sync(self.view)(request)
expected = {
'method': 'POST',
'data': {'test': ['foo']}
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
def test_options_succeeds(self):
request = factory.options('/')
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json')
response = async_to_sync(self.view)(request)
expected = {
'detail': JSON_ERROR
}
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert sanitise_json_error(response.data) == expected
class FunctionBasedViewIntegrationTests(TestCase):
def setUp(self):
self.view = basic_view
def test_get_succeeds(self):
request = factory.get('/')
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', 'user@example.com', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'})
response = self.view(request)
expected = {
'method': 'POST',
'data': {'test': ['foo']}
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
def test_options_succeeds(self):
request = factory.options('/')
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json')
response = self.view(request)
@ -96,6 +228,54 @@ class FunctionBasedViewIntegrationTests(TestCase):
assert sanitise_json_error(response.data) == expected
@pytest.mark.skipif(
django.VERSION < (4, 1),
reason="Async view support requires Django 4.1 or higher",
)
class FunctionBasedAsyncViewIntegrationTests(TestCase):
def setUp(self):
self.view = basic_async_view
def test_get_succeeds(self):
request = factory.get('/')
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_logged_in_get_succeeds(self):
user = User.objects.create_user('user', 'user@example.com', 'password')
request = factory.get('/')
del user.is_active
request.user = user
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {'method': 'GET'}
def test_post_succeeds(self):
request = factory.post('/', {'test': 'foo'})
response = async_to_sync(self.view)(request)
expected = {
'method': 'POST',
'data': {'test': ['foo']}
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
def test_options_succeeds(self):
request = factory.options('/')
response = async_to_sync(self.view)(request)
assert response.status_code == status.HTTP_200_OK
def test_400_parse_error(self):
request = factory.post('/', 'f00bar', content_type='application/json')
response = async_to_sync(self.view)(request)
expected = {
'detail': JSON_ERROR
}
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert sanitise_json_error(response.data) == expected
class TestCustomExceptionHandler(TestCase):
def setUp(self):
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER