mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-18 04:02:35 +03:00
Async view implementation
This commit is contained in:
parent
4f7e9ed3bb
commit
979bb24bec
1
.github/workflows/main.yml
vendored
1
.github/workflows/main.yml
vendored
|
@ -12,6 +12,7 @@ jobs:
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version:
|
python-version:
|
||||||
- '3.6'
|
- '3.6'
|
||||||
|
|
|
@ -217,6 +217,22 @@ You may pass `None` in order to exclude the view from schema generation.
|
||||||
def view(request):
|
def view(request):
|
||||||
return Response({"message": "Will not appear in schema!"})
|
return Response({"message": "Will not appear in schema!"})
|
||||||
|
|
||||||
|
# Async Views
|
||||||
|
|
||||||
|
When using Django 4.1 and above, REST framework allows you to work with async class and function based views.
|
||||||
|
|
||||||
|
For class based views, all handler methods must be async, otherwise Django will raise an exception. For function based views, the function itself must be async.
|
||||||
|
|
||||||
|
For example:
|
||||||
|
|
||||||
|
class AsyncView(APIView):
|
||||||
|
async def get(self, request):
|
||||||
|
return Response({"message": "This is an async class based view."})
|
||||||
|
|
||||||
|
|
||||||
|
@api_view(['GET'])
|
||||||
|
async def async_view(request):
|
||||||
|
return Response({"message": "This is an async function based view."})
|
||||||
|
|
||||||
[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
|
[cite]: https://reinout.vanrees.org/weblog/2011/08/24/class-based-views-usage.html
|
||||||
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
|
[cite2]: http://www.boredomandlaziness.org/2012/05/djangos-cbvs-are-not-mistake-but.html
|
||||||
|
|
|
@ -41,6 +41,17 @@ except ImportError:
|
||||||
uritemplate = None
|
uritemplate = None
|
||||||
|
|
||||||
|
|
||||||
|
# async_to_sync is required for async view support
|
||||||
|
if django.VERSION >= (4, 1):
|
||||||
|
from asgiref.sync import async_to_sync, iscoroutinefunction, sync_to_async
|
||||||
|
else:
|
||||||
|
async_to_sync = None
|
||||||
|
sync_to_async = None
|
||||||
|
|
||||||
|
def iscoroutinefunction(func):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# coreschema is optional
|
# coreschema is optional
|
||||||
try:
|
try:
|
||||||
import coreschema
|
import coreschema
|
||||||
|
|
|
@ -10,6 +10,7 @@ import types
|
||||||
|
|
||||||
from django.forms.utils import pretty_name
|
from django.forms.utils import pretty_name
|
||||||
|
|
||||||
|
from rest_framework.compat import iscoroutinefunction
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
|
||||||
|
@ -46,8 +47,12 @@ def api_view(http_method_names=None):
|
||||||
allowed_methods = set(http_method_names) | {'options'}
|
allowed_methods = set(http_method_names) | {'options'}
|
||||||
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
|
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
|
||||||
|
|
||||||
def handler(self, *args, **kwargs):
|
if iscoroutinefunction(func):
|
||||||
return func(*args, **kwargs)
|
async def handler(self, *args, **kwargs):
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
def handler(self, *args, **kwargs):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
for method in http_method_names:
|
for method in http_method_names:
|
||||||
setattr(WrappedAPIView, method.lower(), handler)
|
setattr(WrappedAPIView, method.lower(), handler)
|
||||||
|
|
|
@ -11,6 +11,10 @@ from django.test import override_settings, testcases
|
||||||
from django.test.client import Client as DjangoClient
|
from django.test.client import Client as DjangoClient
|
||||||
from django.test.client import ClientHandler
|
from django.test.client import ClientHandler
|
||||||
from django.test.client import RequestFactory as DjangoRequestFactory
|
from django.test.client import RequestFactory as DjangoRequestFactory
|
||||||
|
|
||||||
|
if django.VERSION >= (4, 1):
|
||||||
|
from django.test.client import AsyncRequestFactory as DjangoAsyncRequestFactory
|
||||||
|
|
||||||
from django.utils.encoding import force_bytes
|
from django.utils.encoding import force_bytes
|
||||||
from django.utils.http import urlencode
|
from django.utils.http import urlencode
|
||||||
|
|
||||||
|
@ -136,7 +140,7 @@ else:
|
||||||
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
|
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
|
||||||
|
|
||||||
|
|
||||||
class APIRequestFactory(DjangoRequestFactory):
|
class APIRequestFactoryMixin:
|
||||||
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES
|
||||||
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT
|
||||||
|
|
||||||
|
@ -240,6 +244,15 @@ class APIRequestFactory(DjangoRequestFactory):
|
||||||
return request
|
return request
|
||||||
|
|
||||||
|
|
||||||
|
class APIRequestFactory(APIRequestFactoryMixin, DjangoRequestFactory):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if django.VERSION >= (4, 1):
|
||||||
|
class APIAsyncRequestFactory(APIRequestFactoryMixin, DjangoAsyncRequestFactory):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ForceAuthClientHandler(ClientHandler):
|
class ForceAuthClientHandler(ClientHandler):
|
||||||
"""
|
"""
|
||||||
A patched version of ClientHandler that can enforce authentication
|
A patched version of ClientHandler that can enforce authentication
|
||||||
|
|
|
@ -6,6 +6,9 @@ import time
|
||||||
from django.core.cache import cache as default_cache
|
from django.core.cache import cache as default_cache
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
|
from rest_framework.compat import (
|
||||||
|
async_to_sync, iscoroutinefunction, sync_to_async
|
||||||
|
)
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,6 +67,8 @@ class SimpleRateThrottle(BaseThrottle):
|
||||||
cache_format = 'throttle_%(scope)s_%(ident)s'
|
cache_format = 'throttle_%(scope)s_%(ident)s'
|
||||||
scope = None
|
scope = None
|
||||||
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
|
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
|
||||||
|
sync_capable = True
|
||||||
|
async_capable = True
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
if not getattr(self, 'rate', None):
|
if not getattr(self, 'rate', None):
|
||||||
|
@ -113,23 +118,52 @@ class SimpleRateThrottle(BaseThrottle):
|
||||||
On success calls `throttle_success`.
|
On success calls `throttle_success`.
|
||||||
On failure calls `throttle_failure`.
|
On failure calls `throttle_failure`.
|
||||||
"""
|
"""
|
||||||
if self.rate is None:
|
if getattr(view, 'view_is_async', False):
|
||||||
return True
|
|
||||||
|
|
||||||
self.key = self.get_cache_key(request, view)
|
async def func():
|
||||||
if self.key is None:
|
if self.rate is None:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
self.history = self.cache.get(self.key, [])
|
self.key = self.get_cache_key(request, view)
|
||||||
self.now = self.timer()
|
if self.key is None:
|
||||||
|
return True
|
||||||
|
|
||||||
# Drop any requests from the history which have now passed the
|
self.history = self.cache.get(self.key, [])
|
||||||
# throttle duration
|
if iscoroutinefunction(self.timer):
|
||||||
while self.history and self.history[-1] <= self.now - self.duration:
|
self.now = await self.timer()
|
||||||
self.history.pop()
|
else:
|
||||||
if len(self.history) >= self.num_requests:
|
self.now = await sync_to_async(self.timer)()
|
||||||
return self.throttle_failure()
|
|
||||||
return self.throttle_success()
|
# Drop any requests from the history which have now passed the
|
||||||
|
# throttle duration
|
||||||
|
while self.history and self.history[-1] <= self.now - self.duration:
|
||||||
|
self.history.pop()
|
||||||
|
if len(self.history) >= self.num_requests:
|
||||||
|
return self.throttle_failure()
|
||||||
|
return self.throttle_success()
|
||||||
|
|
||||||
|
return func()
|
||||||
|
else:
|
||||||
|
if self.rate is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.key = self.get_cache_key(request, view)
|
||||||
|
if self.key is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
self.history = self.cache.get(self.key, [])
|
||||||
|
if iscoroutinefunction(self.timer):
|
||||||
|
self.now = async_to_sync(self.timer)()
|
||||||
|
else:
|
||||||
|
self.now = self.timer()
|
||||||
|
|
||||||
|
# Drop any requests from the history which have now passed the
|
||||||
|
# throttle duration
|
||||||
|
while self.history and self.history[-1] <= self.now - self.duration:
|
||||||
|
self.history.pop()
|
||||||
|
if len(self.history) >= self.num_requests:
|
||||||
|
return self.throttle_failure()
|
||||||
|
return self.throttle_success()
|
||||||
|
|
||||||
def throttle_success(self):
|
def throttle_success(self):
|
||||||
"""
|
"""
|
||||||
|
@ -210,6 +244,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
|
||||||
user id of the request, and the scope of the view being accessed.
|
user id of the request, and the scope of the view being accessed.
|
||||||
"""
|
"""
|
||||||
scope_attr = 'throttle_scope'
|
scope_attr = 'throttle_scope'
|
||||||
|
sync_capable = True
|
||||||
|
async_capable = True
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# Override the usual SimpleRateThrottle, because we can't determine
|
# Override the usual SimpleRateThrottle, because we can't determine
|
||||||
|
@ -220,17 +256,34 @@ class ScopedRateThrottle(SimpleRateThrottle):
|
||||||
# We can only determine the scope once we're called by the view.
|
# We can only determine the scope once we're called by the view.
|
||||||
self.scope = getattr(view, self.scope_attr, None)
|
self.scope = getattr(view, self.scope_attr, None)
|
||||||
|
|
||||||
# If a view does not have a `throttle_scope` always allow the request
|
if getattr(view, 'view_is_async', False):
|
||||||
if not self.scope:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Determine the allowed request rate as we normally would during
|
async def func(allow_request):
|
||||||
# the `__init__` call.
|
# If a view does not have a `throttle_scope` always allow the request
|
||||||
self.rate = self.get_rate()
|
if not self.scope:
|
||||||
self.num_requests, self.duration = self.parse_rate(self.rate)
|
return True
|
||||||
|
|
||||||
# We can now proceed as normal.
|
# Determine the allowed request rate as we normally would during
|
||||||
return super().allow_request(request, view)
|
# the `__init__` call.
|
||||||
|
self.rate = self.get_rate()
|
||||||
|
self.num_requests, self.duration = self.parse_rate(self.rate)
|
||||||
|
|
||||||
|
# We can now proceed as normal.
|
||||||
|
return await allow_request(request, view)
|
||||||
|
|
||||||
|
return func(super().allow_request)
|
||||||
|
else:
|
||||||
|
# If a view does not have a `throttle_scope` always allow the request
|
||||||
|
if not self.scope:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Determine the allowed request rate as we normally would during
|
||||||
|
# the `__init__` call.
|
||||||
|
self.rate = self.get_rate()
|
||||||
|
self.num_requests, self.duration = self.parse_rate(self.rate)
|
||||||
|
|
||||||
|
# We can now proceed as normal.
|
||||||
|
return super().allow_request(request, view)
|
||||||
|
|
||||||
def get_cache_key(self, request, view):
|
def get_cache_key(self, request, view):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
"""
|
"""
|
||||||
Provides an APIView class that is the base of all views in REST framework.
|
Provides an APIView class that is the base of all views in REST framework.
|
||||||
"""
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.core.exceptions import PermissionDenied
|
from django.core.exceptions import PermissionDenied
|
||||||
from django.db import connections, models
|
from django.db import connections, models
|
||||||
|
@ -12,6 +14,9 @@ from django.views.decorators.csrf import csrf_exempt
|
||||||
from django.views.generic import View
|
from django.views.generic import View
|
||||||
|
|
||||||
from rest_framework import exceptions, status
|
from rest_framework import exceptions, status
|
||||||
|
from rest_framework.compat import (
|
||||||
|
async_to_sync, iscoroutinefunction, sync_to_async
|
||||||
|
)
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.schemas import DefaultSchema
|
from rest_framework.schemas import DefaultSchema
|
||||||
|
@ -328,13 +333,52 @@ class APIView(View):
|
||||||
Check if the request should be permitted.
|
Check if the request should be permitted.
|
||||||
Raises an appropriate exception if the request is not permitted.
|
Raises an appropriate exception if the request is not permitted.
|
||||||
"""
|
"""
|
||||||
|
async_permissions, sync_permissions = [], []
|
||||||
for permission in self.get_permissions():
|
for permission in self.get_permissions():
|
||||||
if not permission.has_permission(request, self):
|
if iscoroutinefunction(permission.has_permission):
|
||||||
self.permission_denied(
|
async_permissions.append(permission)
|
||||||
request,
|
else:
|
||||||
message=getattr(permission, 'message', None),
|
sync_permissions.append(permission)
|
||||||
code=getattr(permission, 'code', None)
|
|
||||||
)
|
async def check_async():
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*(permission.has_permission(request, self) for permission in
|
||||||
|
async_permissions), return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for idx in range(len(async_permissions)):
|
||||||
|
if isinstance(results[idx], Exception):
|
||||||
|
raise results[idx]
|
||||||
|
elif not results[idx]:
|
||||||
|
self.permission_denied(
|
||||||
|
request,
|
||||||
|
message=getattr(async_permissions[idx], "message", None),
|
||||||
|
code=getattr(async_permissions[idx], "code", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_sync():
|
||||||
|
for permission in sync_permissions:
|
||||||
|
if not permission.has_permission(request, self):
|
||||||
|
self.permission_denied(
|
||||||
|
request,
|
||||||
|
message=getattr(permission, 'message', None),
|
||||||
|
code=getattr(permission, 'code', None)
|
||||||
|
)
|
||||||
|
|
||||||
|
if getattr(self, 'view_is_async', False):
|
||||||
|
|
||||||
|
async def func():
|
||||||
|
if async_permissions:
|
||||||
|
await check_async()
|
||||||
|
if sync_permissions:
|
||||||
|
await sync_to_async(check_sync)()
|
||||||
|
|
||||||
|
return func()
|
||||||
|
else:
|
||||||
|
if sync_permissions:
|
||||||
|
check_sync()
|
||||||
|
if async_permissions:
|
||||||
|
async_to_sync(check_async)
|
||||||
|
|
||||||
def check_object_permissions(self, request, obj):
|
def check_object_permissions(self, request, obj):
|
||||||
"""
|
"""
|
||||||
|
@ -354,21 +398,79 @@ class APIView(View):
|
||||||
Check if request should be throttled.
|
Check if request should be throttled.
|
||||||
Raises an appropriate exception if the request is throttled.
|
Raises an appropriate exception if the request is throttled.
|
||||||
"""
|
"""
|
||||||
throttle_durations = []
|
async_throttle_durations, sync_throttle_durations = [], []
|
||||||
|
view_is_async = getattr(self, 'view_is_async', False)
|
||||||
for throttle in self.get_throttles():
|
for throttle in self.get_throttles():
|
||||||
if not throttle.allow_request(request, self):
|
throttle_can_sync = getattr(throttle, "sync_capable", True)
|
||||||
throttle_durations.append(throttle.wait())
|
throttle_can_async = getattr(throttle, "async_capable", False)
|
||||||
|
if not throttle_can_sync and not throttle_can_async:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Throttle %s must have at least one of "
|
||||||
|
"sync_capable/async_capable set to True." % throttle.__class__.__name__
|
||||||
|
)
|
||||||
|
elif not view_is_async and throttle_can_sync:
|
||||||
|
throttle_is_async = False
|
||||||
|
elif iscoroutinefunction(throttle.allow_request):
|
||||||
|
throttle_is_async = True
|
||||||
|
else:
|
||||||
|
throttle_is_async = throttle_can_async
|
||||||
|
if throttle_is_async:
|
||||||
|
async_throttle_durations.append(throttle)
|
||||||
|
else:
|
||||||
|
sync_throttle_durations.append(throttle)
|
||||||
|
|
||||||
if throttle_durations:
|
async def async_throttles():
|
||||||
# Filter out `None` values which may happen in case of config / rate
|
for throttle in async_throttle_durations:
|
||||||
# changes, see #1438
|
if not await throttle.allow_request(request, self):
|
||||||
durations = [
|
yield throttle.wait()
|
||||||
duration for duration in throttle_durations
|
|
||||||
if duration is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
duration = max(durations, default=None)
|
def sync_throttles():
|
||||||
self.throttled(request, duration)
|
for throttle in sync_throttle_durations:
|
||||||
|
if not throttle.allow_request(request, self):
|
||||||
|
yield throttle.wait()
|
||||||
|
|
||||||
|
if view_is_async:
|
||||||
|
|
||||||
|
async def func():
|
||||||
|
throttle_durations = []
|
||||||
|
|
||||||
|
if async_throttle_durations:
|
||||||
|
throttle_durations.extend([duration async for duration in async_throttles()])
|
||||||
|
|
||||||
|
if sync_throttle_durations:
|
||||||
|
throttle_durations.extend(duration for duration in sync_throttles())
|
||||||
|
|
||||||
|
if throttle_durations:
|
||||||
|
# Filter out `None` values which may happen in case of config / rate
|
||||||
|
# changes, see #1438
|
||||||
|
durations = [
|
||||||
|
duration for duration in throttle_durations
|
||||||
|
if duration is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
duration = max(durations, default=None)
|
||||||
|
self.throttled(request, duration)
|
||||||
|
|
||||||
|
return func()
|
||||||
|
else:
|
||||||
|
throttle_durations = []
|
||||||
|
|
||||||
|
if sync_throttle_durations:
|
||||||
|
throttle_durations.extend(sync_throttles())
|
||||||
|
|
||||||
|
if async_throttle_durations:
|
||||||
|
throttle_durations.extend(async_to_sync(async_throttles)())
|
||||||
|
|
||||||
|
if throttle_durations:
|
||||||
|
# Filter out `None` values which may happen in case of config / rate
|
||||||
|
# changes, see #1438
|
||||||
|
durations = [
|
||||||
|
duration for duration in throttle_durations
|
||||||
|
if duration is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
duration = max(durations, default=None)
|
||||||
|
self.throttled(request, duration)
|
||||||
|
|
||||||
def determine_version(self, request, *args, **kwargs):
|
def determine_version(self, request, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -410,10 +512,20 @@ class APIView(View):
|
||||||
version, scheme = self.determine_version(request, *args, **kwargs)
|
version, scheme = self.determine_version(request, *args, **kwargs)
|
||||||
request.version, request.versioning_scheme = version, scheme
|
request.version, request.versioning_scheme = version, scheme
|
||||||
|
|
||||||
# Ensure that the incoming request is permitted
|
if getattr(self, 'view_is_async', False):
|
||||||
self.perform_authentication(request)
|
|
||||||
self.check_permissions(request)
|
async def func():
|
||||||
self.check_throttles(request)
|
# Ensure that the incoming request is permitted
|
||||||
|
await sync_to_async(self.perform_authentication)(request)
|
||||||
|
await self.check_permissions(request)
|
||||||
|
await self.check_throttles(request)
|
||||||
|
|
||||||
|
return func()
|
||||||
|
else:
|
||||||
|
# Ensure that the incoming request is permitted
|
||||||
|
self.perform_authentication(request)
|
||||||
|
self.check_permissions(request)
|
||||||
|
self.check_throttles(request)
|
||||||
|
|
||||||
def finalize_response(self, request, response, *args, **kwargs):
|
def finalize_response(self, request, response, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -469,7 +581,15 @@ class APIView(View):
|
||||||
self.raise_uncaught_exception(exc)
|
self.raise_uncaught_exception(exc)
|
||||||
|
|
||||||
response.exception = True
|
response.exception = True
|
||||||
return response
|
|
||||||
|
if getattr(self, 'view_is_async', False):
|
||||||
|
|
||||||
|
async def func():
|
||||||
|
return response
|
||||||
|
|
||||||
|
return func()
|
||||||
|
else:
|
||||||
|
return response
|
||||||
|
|
||||||
def raise_uncaught_exception(self, exc):
|
def raise_uncaught_exception(self, exc):
|
||||||
if settings.DEBUG:
|
if settings.DEBUG:
|
||||||
|
@ -493,23 +613,49 @@ class APIView(View):
|
||||||
self.request = request
|
self.request = request
|
||||||
self.headers = self.default_response_headers # deprecate?
|
self.headers = self.default_response_headers # deprecate?
|
||||||
|
|
||||||
try:
|
if getattr(self, 'view_is_async', False):
|
||||||
self.initial(request, *args, **kwargs)
|
|
||||||
|
|
||||||
# Get the appropriate handler method
|
async def func():
|
||||||
if request.method.lower() in self.http_method_names:
|
|
||||||
handler = getattr(self, request.method.lower(),
|
|
||||||
self.http_method_not_allowed)
|
|
||||||
else:
|
|
||||||
handler = self.http_method_not_allowed
|
|
||||||
|
|
||||||
response = handler(request, *args, **kwargs)
|
try:
|
||||||
|
await self.initial(request, *args, **kwargs)
|
||||||
|
|
||||||
except Exception as exc:
|
# Get the appropriate handler method
|
||||||
response = self.handle_exception(exc)
|
if request.method.lower() in self.http_method_names:
|
||||||
|
handler = getattr(self, request.method.lower(),
|
||||||
|
self.http_method_not_allowed)
|
||||||
|
else:
|
||||||
|
handler = self.http_method_not_allowed
|
||||||
|
|
||||||
self.response = self.finalize_response(request, response, *args, **kwargs)
|
response = await handler(request, *args, **kwargs)
|
||||||
return self.response
|
|
||||||
|
except Exception as exc:
|
||||||
|
response = await self.handle_exception(exc)
|
||||||
|
|
||||||
|
return self.finalize_response(request, response, *args, **kwargs)
|
||||||
|
|
||||||
|
self.response = func()
|
||||||
|
|
||||||
|
return self.response
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.initial(request, *args, **kwargs)
|
||||||
|
|
||||||
|
# Get the appropriate handler method
|
||||||
|
if request.method.lower() in self.http_method_names:
|
||||||
|
handler = getattr(self, request.method.lower(),
|
||||||
|
self.http_method_not_allowed)
|
||||||
|
else:
|
||||||
|
handler = self.http_method_not_allowed
|
||||||
|
|
||||||
|
response = handler(request, *args, **kwargs)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
response = self.handle_exception(exc)
|
||||||
|
|
||||||
|
self.response = self.finalize_response(request, response, *args, **kwargs)
|
||||||
|
|
||||||
|
return self.response
|
||||||
|
|
||||||
def options(self, request, *args, **kwargs):
|
def options(self, request, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -518,4 +664,12 @@ class APIView(View):
|
||||||
if self.metadata_class is None:
|
if self.metadata_class is None:
|
||||||
return self.http_method_not_allowed(request, *args, **kwargs)
|
return self.http_method_not_allowed(request, *args, **kwargs)
|
||||||
data = self.metadata_class().determine_metadata(request, self)
|
data = self.metadata_class().determine_metadata(request, self)
|
||||||
return Response(data, status=status.HTTP_200_OK)
|
|
||||||
|
if getattr(self, 'view_is_async', False):
|
||||||
|
|
||||||
|
async def func():
|
||||||
|
return Response(data, status=status.HTTP_200_OK)
|
||||||
|
|
||||||
|
return func()
|
||||||
|
else:
|
||||||
|
return Response(data, status=status.HTTP_200_OK)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Tests for the throttling implementations in the permissions module.
|
Tests for the throttling implementations in the permissions module.
|
||||||
"""
|
"""
|
||||||
|
import django
|
||||||
import pytest
|
import pytest
|
||||||
from django.contrib.auth.models import User
|
from django.contrib.auth.models import User
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
@ -9,10 +9,15 @@ from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
|
from rest_framework.compat import async_to_sync
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
from rest_framework.test import APIRequestFactory, force_authenticate
|
from rest_framework.test import APIRequestFactory, force_authenticate
|
||||||
|
|
||||||
|
if django.VERSION >= (4, 1):
|
||||||
|
from rest_framework.test import APIAsyncRequestFactory
|
||||||
|
|
||||||
from rest_framework.throttling import (
|
from rest_framework.throttling import (
|
||||||
AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle,
|
AnonRateThrottle, BaseThrottle, ScopedRateThrottle, SimpleRateThrottle,
|
||||||
UserRateThrottle
|
UserRateThrottle
|
||||||
|
@ -43,6 +48,14 @@ class NonTimeThrottle(BaseThrottle):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class NonTimeAsyncThrottle(BaseThrottle):
|
||||||
|
def allow_request(self, request, view):
|
||||||
|
if not hasattr(self.__class__, 'called'):
|
||||||
|
self.__class__.called = True
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class MockView_DoubleThrottling(APIView):
|
class MockView_DoubleThrottling(APIView):
|
||||||
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
|
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
|
||||||
|
|
||||||
|
@ -50,6 +63,13 @@ class MockView_DoubleThrottling(APIView):
|
||||||
return Response('foo')
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncView_DoubleThrottling(APIView):
|
||||||
|
throttle_classes = (User3SecRateThrottle, User6MinRateThrottle,)
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class MockView(APIView):
|
class MockView(APIView):
|
||||||
throttle_classes = (User3SecRateThrottle,)
|
throttle_classes = (User3SecRateThrottle,)
|
||||||
|
|
||||||
|
@ -57,6 +77,13 @@ class MockView(APIView):
|
||||||
return Response('foo')
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncView(APIView):
|
||||||
|
throttle_classes = (User3SecRateThrottle,)
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class MockView_MinuteThrottling(APIView):
|
class MockView_MinuteThrottling(APIView):
|
||||||
throttle_classes = (User3MinRateThrottle,)
|
throttle_classes = (User3MinRateThrottle,)
|
||||||
|
|
||||||
|
@ -64,6 +91,13 @@ class MockView_MinuteThrottling(APIView):
|
||||||
return Response('foo')
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncView_MinuteThrottling(APIView):
|
||||||
|
throttle_classes = (User3MinRateThrottle,)
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class MockView_NonTimeThrottling(APIView):
|
class MockView_NonTimeThrottling(APIView):
|
||||||
throttle_classes = (NonTimeThrottle,)
|
throttle_classes = (NonTimeThrottle,)
|
||||||
|
|
||||||
|
@ -71,6 +105,13 @@ class MockView_NonTimeThrottling(APIView):
|
||||||
return Response('foo')
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncView_NonTimeThrottling(APIView):
|
||||||
|
throttle_classes = (NonTimeAsyncThrottle,)
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class ThrottlingTests(TestCase):
|
class ThrottlingTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""
|
"""
|
||||||
|
@ -252,12 +293,198 @@ class ThrottlingTests(TestCase):
|
||||||
self.assertFalse('Retry-After' in response)
|
self.assertFalse('Retry-After' in response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
django.VERSION < (4, 1),
|
||||||
|
reason="Async view support requires Django 4.1 or higher",
|
||||||
|
)
|
||||||
|
class AsyncThrottlingTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
"""
|
||||||
|
Reset the cache so that no throttles will be active
|
||||||
|
"""
|
||||||
|
cache.clear()
|
||||||
|
self.factory = APIAsyncRequestFactory()
|
||||||
|
|
||||||
|
def test_requests_are_throttled(self):
|
||||||
|
"""
|
||||||
|
Ensure request rate is limited
|
||||||
|
"""
|
||||||
|
request = self.factory.get('/')
|
||||||
|
for dummy in range(4):
|
||||||
|
response = async_to_sync(MockAsyncView.as_view())(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
def set_throttle_timer(self, view, value):
|
||||||
|
"""
|
||||||
|
Explicitly set the timer, overriding time.time()
|
||||||
|
"""
|
||||||
|
for cls in view.throttle_classes:
|
||||||
|
cls.timer = lambda self: value
|
||||||
|
|
||||||
|
def test_request_throttling_expires(self):
|
||||||
|
"""
|
||||||
|
Ensure request rate is limited for a limited duration only
|
||||||
|
"""
|
||||||
|
self.set_throttle_timer(MockAsyncView, 0)
|
||||||
|
|
||||||
|
request = self.factory.get('/')
|
||||||
|
for dummy in range(4):
|
||||||
|
response = async_to_sync(MockAsyncView.as_view())(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# Advance the timer by one second
|
||||||
|
self.set_throttle_timer(MockAsyncView, 1)
|
||||||
|
|
||||||
|
response = async_to_sync(MockAsyncView.as_view())(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
async def ensure_is_throttled(self, view, expect):
|
||||||
|
request = self.factory.get('/')
|
||||||
|
request.user = await User.objects.acreate(username='a')
|
||||||
|
for dummy in range(3):
|
||||||
|
await view.as_view()(request)
|
||||||
|
request.user = await User.objects.acreate(username='b')
|
||||||
|
response = await view.as_view()(request)
|
||||||
|
assert response.status_code == expect
|
||||||
|
|
||||||
|
def test_request_throttling_is_per_user(self):
|
||||||
|
"""
|
||||||
|
Ensure request rate is only limited per user, not globally for
|
||||||
|
PerUserThrottles
|
||||||
|
"""
|
||||||
|
async_to_sync(self.ensure_is_throttled)(MockAsyncView, 200)
|
||||||
|
|
||||||
|
def test_request_throttling_multiple_throttles(self):
|
||||||
|
"""
|
||||||
|
Ensure all throttle classes see each request even when the request is
|
||||||
|
already being throttled
|
||||||
|
"""
|
||||||
|
self.set_throttle_timer(MockAsyncView_DoubleThrottling, 0)
|
||||||
|
request = self.factory.get('/')
|
||||||
|
for dummy in range(4):
|
||||||
|
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert int(response['retry-after']) == 1
|
||||||
|
|
||||||
|
# At this point our client made 4 requests (one was throttled) in a
|
||||||
|
# second. If we advance the timer by one additional second, the client
|
||||||
|
# should be allowed to make 2 more before being throttled by the 2nd
|
||||||
|
# throttle class, which has a limit of 6 per minute.
|
||||||
|
self.set_throttle_timer(MockAsyncView_DoubleThrottling, 1)
|
||||||
|
for dummy in range(2):
|
||||||
|
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert int(response['retry-after']) == 59
|
||||||
|
|
||||||
|
# Just to make sure check again after two more seconds.
|
||||||
|
self.set_throttle_timer(MockAsyncView_DoubleThrottling, 2)
|
||||||
|
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert int(response['retry-after']) == 58
|
||||||
|
|
||||||
|
def test_throttle_rate_change_negative(self):
|
||||||
|
self.set_throttle_timer(MockAsyncView_DoubleThrottling, 0)
|
||||||
|
request = self.factory.get('/')
|
||||||
|
for dummy in range(24):
|
||||||
|
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert int(response['retry-after']) == 60
|
||||||
|
|
||||||
|
previous_rate = User3SecRateThrottle.rate
|
||||||
|
try:
|
||||||
|
User3SecRateThrottle.rate = '1/sec'
|
||||||
|
|
||||||
|
for dummy in range(24):
|
||||||
|
response = async_to_sync(MockAsyncView_DoubleThrottling.as_view())(request)
|
||||||
|
|
||||||
|
assert response.status_code == 429
|
||||||
|
assert int(response['retry-after']) == 60
|
||||||
|
finally:
|
||||||
|
# reset
|
||||||
|
User3SecRateThrottle.rate = previous_rate
|
||||||
|
|
||||||
|
async def ensure_response_header_contains_proper_throttle_field(self, view, expected_headers):
|
||||||
|
"""
|
||||||
|
Ensure the response returns an Retry-After field with status and next attributes
|
||||||
|
set properly.
|
||||||
|
"""
|
||||||
|
request = self.factory.get('/')
|
||||||
|
for timer, expect in expected_headers:
|
||||||
|
self.set_throttle_timer(view, timer)
|
||||||
|
response = await view.as_view()(request)
|
||||||
|
if expect is not None:
|
||||||
|
assert response['Retry-After'] == expect
|
||||||
|
else:
|
||||||
|
assert not'Retry-After' in response
|
||||||
|
|
||||||
|
def test_seconds_fields(self):
|
||||||
|
"""
|
||||||
|
Ensure for second based throttles.
|
||||||
|
"""
|
||||||
|
async_to_sync(self.ensure_response_header_contains_proper_throttle_field)(
|
||||||
|
MockAsyncView, (
|
||||||
|
(0, None),
|
||||||
|
(0, None),
|
||||||
|
(0, None),
|
||||||
|
(0, '1')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_minutes_fields(self):
|
||||||
|
"""
|
||||||
|
Ensure for minute based throttles.
|
||||||
|
"""
|
||||||
|
async_to_sync(self.ensure_response_header_contains_proper_throttle_field)(
|
||||||
|
MockAsyncView_MinuteThrottling, (
|
||||||
|
(0, None),
|
||||||
|
(0, None),
|
||||||
|
(0, None),
|
||||||
|
(0, '60')
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_next_rate_remains_constant_if_followed(self):
|
||||||
|
"""
|
||||||
|
If a client follows the recommended next request rate,
|
||||||
|
the throttling rate should stay constant.
|
||||||
|
"""
|
||||||
|
async_to_sync(self.ensure_response_header_contains_proper_throttle_field)(
|
||||||
|
MockAsyncView_MinuteThrottling, (
|
||||||
|
(0, None),
|
||||||
|
(20, None),
|
||||||
|
(40, None),
|
||||||
|
(60, None),
|
||||||
|
(80, None)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_non_time_throttle(self):
|
||||||
|
"""
|
||||||
|
Ensure for second based throttles.
|
||||||
|
"""
|
||||||
|
request = self.factory.get('/')
|
||||||
|
|
||||||
|
self.assertFalse(hasattr(MockAsyncView_NonTimeThrottling.throttle_classes[0], 'called'))
|
||||||
|
|
||||||
|
response = async_to_sync(MockAsyncView_NonTimeThrottling.as_view())(request)
|
||||||
|
self.assertFalse('Retry-After' in response)
|
||||||
|
|
||||||
|
self.assertTrue(MockAsyncView_NonTimeThrottling.throttle_classes[0].called)
|
||||||
|
|
||||||
|
response = async_to_sync(MockAsyncView_NonTimeThrottling.as_view())(request)
|
||||||
|
self.assertFalse('Retry-After' in response)
|
||||||
|
|
||||||
|
|
||||||
class ScopedRateThrottleTests(TestCase):
|
class ScopedRateThrottleTests(TestCase):
|
||||||
"""
|
"""
|
||||||
Tests for ScopedRateThrottle.
|
Tests for ScopedRateThrottle.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
cache.clear()
|
||||||
self.throttle = ScopedRateThrottle()
|
self.throttle = ScopedRateThrottle()
|
||||||
|
|
||||||
class XYScopedRateThrottle(ScopedRateThrottle):
|
class XYScopedRateThrottle(ScopedRateThrottle):
|
||||||
|
@ -372,6 +599,131 @@ class ScopedRateThrottleTests(TestCase):
|
||||||
assert cache_key == 'throttle_user_%s' % user.pk
|
assert cache_key == 'throttle_user_%s' % user.pk
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
django.VERSION < (4, 1),
|
||||||
|
reason="Async view support requires Django 4.1 or higher",
|
||||||
|
)
|
||||||
|
class AsyncScopedRateThrottleTests(TestCase):
|
||||||
|
"""
|
||||||
|
Tests for ScopedRateThrottle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
cache.clear()
|
||||||
|
self.throttle = ScopedRateThrottle()
|
||||||
|
|
||||||
|
class XYScopedRateThrottle(ScopedRateThrottle):
|
||||||
|
TIMER_SECONDS = 0
|
||||||
|
THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
|
||||||
|
|
||||||
|
async def timer(self):
|
||||||
|
return self.TIMER_SECONDS
|
||||||
|
|
||||||
|
class XView(APIView):
|
||||||
|
throttle_classes = (XYScopedRateThrottle,)
|
||||||
|
throttle_scope = 'x'
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
return Response('x')
|
||||||
|
|
||||||
|
class YView(APIView):
|
||||||
|
throttle_classes = (XYScopedRateThrottle,)
|
||||||
|
throttle_scope = 'y'
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
return Response('y')
|
||||||
|
|
||||||
|
class UnscopedView(APIView):
|
||||||
|
throttle_classes = (XYScopedRateThrottle,)
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
return Response('y')
|
||||||
|
|
||||||
|
self.throttle_class = XYScopedRateThrottle
|
||||||
|
self.factory = APIAsyncRequestFactory()
|
||||||
|
self.x_view = XView.as_view()
|
||||||
|
self.y_view = YView.as_view()
|
||||||
|
self.unscoped_view = UnscopedView.as_view()
|
||||||
|
|
||||||
|
def increment_timer(self, seconds=1):
|
||||||
|
self.throttle_class.TIMER_SECONDS += seconds
|
||||||
|
|
||||||
|
def test_scoped_rate_throttle(self):
|
||||||
|
request = self.factory.get('/')
|
||||||
|
|
||||||
|
# Should be able to hit x view 3 times per minute.
|
||||||
|
response = async_to_sync(self.x_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.x_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.x_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.x_view)(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# Should be able to hit y view 1 time per minute.
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.y_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.y_view)(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# Ensure throttles properly reset by advancing the rest of the minute
|
||||||
|
self.increment_timer(55)
|
||||||
|
|
||||||
|
# Should still be able to hit x view 3 times per minute.
|
||||||
|
response = async_to_sync(self.x_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.x_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.x_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.x_view)(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
# Should still be able to hit y view 1 time per minute.
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.y_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.y_view)(request)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
def test_unscoped_view_not_throttled(self):
|
||||||
|
request = self.factory.get('/')
|
||||||
|
|
||||||
|
for idx in range(10):
|
||||||
|
self.increment_timer()
|
||||||
|
response = async_to_sync(self.unscoped_view)(request)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_get_cache_key_returns_correct_key_if_user_is_authenticated(self):
|
||||||
|
class DummyView:
|
||||||
|
throttle_scope = 'user'
|
||||||
|
|
||||||
|
request = Request(HttpRequest())
|
||||||
|
user = User.objects.create(username='test')
|
||||||
|
force_authenticate(request, user)
|
||||||
|
request.user = user
|
||||||
|
self.throttle.allow_request(request, DummyView())
|
||||||
|
cache_key = self.throttle.get_cache_key(request, view=DummyView())
|
||||||
|
assert cache_key == 'throttle_user_%s' % user.pk
|
||||||
|
|
||||||
|
|
||||||
class XffTestingBase(TestCase):
|
class XffTestingBase(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|
||||||
|
@ -400,6 +752,34 @@ class XffTestingBase(TestCase):
|
||||||
setattr(api_settings, 'NUM_PROXIES', num_proxies)
|
setattr(api_settings, 'NUM_PROXIES', num_proxies)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncXffTestingBase(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
|
||||||
|
class Throttle(ScopedRateThrottle):
|
||||||
|
THROTTLE_RATES = {'test_limit': '1/day'}
|
||||||
|
TIMER_SECONDS = 0
|
||||||
|
|
||||||
|
async def timer(self):
|
||||||
|
return self.TIMER_SECONDS
|
||||||
|
|
||||||
|
class View(APIView):
|
||||||
|
throttle_classes = (Throttle,)
|
||||||
|
throttle_scope = 'test_limit'
|
||||||
|
|
||||||
|
async def get(self, request):
|
||||||
|
return Response('test_limit')
|
||||||
|
|
||||||
|
cache.clear()
|
||||||
|
self.throttle = Throttle()
|
||||||
|
self.view = View.as_view()
|
||||||
|
self.request = APIAsyncRequestFactory().get('/some_uri')
|
||||||
|
self.request.META['REMOTE_ADDR'] = '3.3.3.3'
|
||||||
|
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 2.2.2.2'
|
||||||
|
|
||||||
|
def config_proxy(self, num_proxies):
|
||||||
|
setattr(api_settings, 'NUM_PROXIES', num_proxies)
|
||||||
|
|
||||||
|
|
||||||
class IdWithXffBasicTests(XffTestingBase):
|
class IdWithXffBasicTests(XffTestingBase):
|
||||||
def test_accepts_request_under_limit(self):
|
def test_accepts_request_under_limit(self):
|
||||||
self.config_proxy(0)
|
self.config_proxy(0)
|
||||||
|
@ -411,6 +791,21 @@ class IdWithXffBasicTests(XffTestingBase):
|
||||||
assert self.view(self.request).status_code == 429
|
assert self.view(self.request).status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
django.VERSION < (4, 1),
|
||||||
|
reason="Async view support requires Django 4.1 or higher",
|
||||||
|
)
|
||||||
|
class AsyncIdWithXffBasicTests(AsyncXffTestingBase):
|
||||||
|
def test_accepts_request_under_limit(self):
|
||||||
|
self.config_proxy(0)
|
||||||
|
assert async_to_sync(self.view)(self.request).status_code == 200
|
||||||
|
|
||||||
|
def test_denies_request_over_limit(self):
|
||||||
|
self.config_proxy(0)
|
||||||
|
async_to_sync(self.view)(self.request)
|
||||||
|
assert async_to_sync(self.view)(self.request).status_code == 429
|
||||||
|
|
||||||
|
|
||||||
class XffSpoofingTests(XffTestingBase):
|
class XffSpoofingTests(XffTestingBase):
|
||||||
def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
|
def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
|
||||||
self.config_proxy(1)
|
self.config_proxy(1)
|
||||||
|
@ -425,6 +820,24 @@ class XffSpoofingTests(XffTestingBase):
|
||||||
assert self.view(self.request).status_code == 429
|
assert self.view(self.request).status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
django.VERSION < (4, 1),
|
||||||
|
reason="Async view support requires Django 4.1 or higher",
|
||||||
|
)
|
||||||
|
class AsyncXffSpoofingTests(AsyncXffTestingBase):
|
||||||
|
def test_xff_spoofing_doesnt_change_machine_id_with_one_app_proxy(self):
|
||||||
|
self.config_proxy(1)
|
||||||
|
async_to_sync(self.view)(self.request)
|
||||||
|
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 5.5.5.5, 2.2.2.2'
|
||||||
|
assert async_to_sync(self.view)(self.request).status_code == 429
|
||||||
|
|
||||||
|
def test_xff_spoofing_doesnt_change_machine_id_with_two_app_proxies(self):
|
||||||
|
self.config_proxy(2)
|
||||||
|
async_to_sync(self.view)(self.request)
|
||||||
|
self.request.META['HTTP_X_FORWARDED_FOR'] = '4.4.4.4, 1.1.1.1, 2.2.2.2'
|
||||||
|
assert async_to_sync(self.view)(self.request).status_code == 429
|
||||||
|
|
||||||
|
|
||||||
class XffUniqueMachinesTest(XffTestingBase):
|
class XffUniqueMachinesTest(XffTestingBase):
|
||||||
def test_unique_clients_are_counted_independently_with_one_proxy(self):
|
def test_unique_clients_are_counted_independently_with_one_proxy(self):
|
||||||
self.config_proxy(1)
|
self.config_proxy(1)
|
||||||
|
@ -439,6 +852,24 @@ class XffUniqueMachinesTest(XffTestingBase):
|
||||||
assert self.view(self.request).status_code == 200
|
assert self.view(self.request).status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
django.VERSION < (4, 1),
|
||||||
|
reason="Async view support requires Django 4.1 or higher",
|
||||||
|
)
|
||||||
|
class AsyncXffUniqueMachinesTest(AsyncXffTestingBase):
|
||||||
|
def test_unique_clients_are_counted_independently_with_one_proxy(self):
|
||||||
|
self.config_proxy(1)
|
||||||
|
async_to_sync(self.view)(self.request)
|
||||||
|
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 1.1.1.1, 7.7.7.7'
|
||||||
|
assert async_to_sync(self.view)(self.request).status_code == 200
|
||||||
|
|
||||||
|
def test_unique_clients_are_counted_independently_with_two_proxies(self):
|
||||||
|
self.config_proxy(2)
|
||||||
|
async_to_sync(self.view)(self.request)
|
||||||
|
self.request.META['HTTP_X_FORWARDED_FOR'] = '0.0.0.0, 7.7.7.7, 2.2.2.2'
|
||||||
|
assert async_to_sync(self.view)(self.request).status_code == 200
|
||||||
|
|
||||||
|
|
||||||
class BaseThrottleTests(TestCase):
|
class BaseThrottleTests(TestCase):
|
||||||
|
|
||||||
def test_allow_request_raises_not_implemented_error(self):
|
def test_allow_request_raises_not_implemented_error(self):
|
||||||
|
|
|
@ -1,8 +1,12 @@
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
import django
|
||||||
|
import pytest
|
||||||
|
from django.contrib.auth.models import User
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
|
from rest_framework.compat import async_to_sync
|
||||||
from rest_framework.decorators import api_view
|
from rest_framework.decorators import api_view
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.settings import APISettings, api_settings
|
from rest_framework.settings import APISettings, api_settings
|
||||||
|
@ -22,16 +26,36 @@ class BasicView(APIView):
|
||||||
return Response({'method': 'POST', 'data': request.data})
|
return Response({'method': 'POST', 'data': request.data})
|
||||||
|
|
||||||
|
|
||||||
|
class BasicAsyncView(APIView):
|
||||||
|
async def get(self, request, *args, **kwargs):
|
||||||
|
return Response({'method': 'GET'})
|
||||||
|
|
||||||
|
async def post(self, request, *args, **kwargs):
|
||||||
|
return Response({'method': 'POST', 'data': request.data})
|
||||||
|
|
||||||
|
|
||||||
@api_view(['GET', 'POST', 'PUT', 'PATCH'])
|
@api_view(['GET', 'POST', 'PUT', 'PATCH'])
|
||||||
def basic_view(request):
|
def basic_view(request):
|
||||||
if request.method == 'GET':
|
if request.method == 'GET':
|
||||||
return {'method': 'GET'}
|
return Response({'method': 'GET'})
|
||||||
elif request.method == 'POST':
|
elif request.method == 'POST':
|
||||||
return {'method': 'POST', 'data': request.data}
|
return Response({'method': 'POST', 'data': request.data})
|
||||||
elif request.method == 'PUT':
|
elif request.method == 'PUT':
|
||||||
return {'method': 'PUT', 'data': request.data}
|
return Response({'method': 'PUT', 'data': request.data})
|
||||||
elif request.method == 'PATCH':
|
elif request.method == 'PATCH':
|
||||||
return {'method': 'PATCH', 'data': request.data}
|
return Response({'method': 'PATCH', 'data': request.data})
|
||||||
|
|
||||||
|
|
||||||
|
@api_view(['GET', 'POST', 'PUT', 'PATCH'])
|
||||||
|
async def basic_async_view(request):
|
||||||
|
if request.method == 'GET':
|
||||||
|
return Response({'method': 'GET'})
|
||||||
|
elif request.method == 'POST':
|
||||||
|
return Response({'method': 'POST', 'data': request.data})
|
||||||
|
elif request.method == 'PUT':
|
||||||
|
return Response({'method': 'PUT', 'data': request.data})
|
||||||
|
elif request.method == 'PATCH':
|
||||||
|
return Response({'method': 'PATCH', 'data': request.data})
|
||||||
|
|
||||||
|
|
||||||
class ErrorView(APIView):
|
class ErrorView(APIView):
|
||||||
|
@ -72,6 +96,36 @@ class ClassBasedViewIntegrationTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.view = BasicView.as_view()
|
self.view = BasicView.as_view()
|
||||||
|
|
||||||
|
def test_get_succeeds(self):
|
||||||
|
request = factory.get('/')
|
||||||
|
response = self.view(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {'method': 'GET'}
|
||||||
|
|
||||||
|
def test_logged_in_get_succeeds(self):
|
||||||
|
user = User.objects.create_user('user', 'user@example.com', 'password')
|
||||||
|
request = factory.get('/')
|
||||||
|
del user.is_active
|
||||||
|
request.user = user
|
||||||
|
response = self.view(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {'method': 'GET'}
|
||||||
|
|
||||||
|
def test_post_succeeds(self):
|
||||||
|
request = factory.post('/', {'test': 'foo'})
|
||||||
|
response = self.view(request)
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'data': {'test': ['foo']}
|
||||||
|
}
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == expected
|
||||||
|
|
||||||
|
def test_options_succeeds(self):
|
||||||
|
request = factory.options('/')
|
||||||
|
response = self.view(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
def test_400_parse_error(self):
|
def test_400_parse_error(self):
|
||||||
request = factory.post('/', 'f00bar', content_type='application/json')
|
request = factory.post('/', 'f00bar', content_type='application/json')
|
||||||
response = self.view(request)
|
response = self.view(request)
|
||||||
|
@ -82,10 +136,88 @@ class ClassBasedViewIntegrationTests(TestCase):
|
||||||
assert sanitise_json_error(response.data) == expected
|
assert sanitise_json_error(response.data) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
django.VERSION < (4, 1),
|
||||||
|
reason="Async view support requires Django 4.1 or higher",
|
||||||
|
)
|
||||||
|
class ClassBasedAsyncViewIntegrationTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.view = BasicAsyncView.as_view()
|
||||||
|
|
||||||
|
def test_get_succeeds(self):
|
||||||
|
request = factory.get('/')
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {'method': 'GET'}
|
||||||
|
|
||||||
|
def test_logged_in_get_succeeds(self):
|
||||||
|
user = User.objects.create_user('user', 'user@example.com', 'password')
|
||||||
|
request = factory.get('/')
|
||||||
|
del user.is_active
|
||||||
|
request.user = user
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {'method': 'GET'}
|
||||||
|
|
||||||
|
def test_post_succeeds(self):
|
||||||
|
request = factory.post('/', {'test': 'foo'})
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'data': {'test': ['foo']}
|
||||||
|
}
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == expected
|
||||||
|
|
||||||
|
def test_options_succeeds(self):
|
||||||
|
request = factory.options('/')
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
def test_400_parse_error(self):
|
||||||
|
request = factory.post('/', 'f00bar', content_type='application/json')
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
expected = {
|
||||||
|
'detail': JSON_ERROR
|
||||||
|
}
|
||||||
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert sanitise_json_error(response.data) == expected
|
||||||
|
|
||||||
|
|
||||||
class FunctionBasedViewIntegrationTests(TestCase):
|
class FunctionBasedViewIntegrationTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.view = basic_view
|
self.view = basic_view
|
||||||
|
|
||||||
|
def test_get_succeeds(self):
|
||||||
|
request = factory.get('/')
|
||||||
|
response = self.view(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {'method': 'GET'}
|
||||||
|
|
||||||
|
def test_logged_in_get_succeeds(self):
|
||||||
|
user = User.objects.create_user('user', 'user@example.com', 'password')
|
||||||
|
request = factory.get('/')
|
||||||
|
del user.is_active
|
||||||
|
request.user = user
|
||||||
|
response = self.view(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {'method': 'GET'}
|
||||||
|
|
||||||
|
def test_post_succeeds(self):
|
||||||
|
request = factory.post('/', {'test': 'foo'})
|
||||||
|
response = self.view(request)
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'data': {'test': ['foo']}
|
||||||
|
}
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == expected
|
||||||
|
|
||||||
|
def test_options_succeeds(self):
|
||||||
|
request = factory.options('/')
|
||||||
|
response = self.view(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
def test_400_parse_error(self):
|
def test_400_parse_error(self):
|
||||||
request = factory.post('/', 'f00bar', content_type='application/json')
|
request = factory.post('/', 'f00bar', content_type='application/json')
|
||||||
response = self.view(request)
|
response = self.view(request)
|
||||||
|
@ -96,6 +228,54 @@ class FunctionBasedViewIntegrationTests(TestCase):
|
||||||
assert sanitise_json_error(response.data) == expected
|
assert sanitise_json_error(response.data) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
django.VERSION < (4, 1),
|
||||||
|
reason="Async view support requires Django 4.1 or higher",
|
||||||
|
)
|
||||||
|
class FunctionBasedAsyncViewIntegrationTests(TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.view = basic_async_view
|
||||||
|
|
||||||
|
def test_get_succeeds(self):
|
||||||
|
request = factory.get('/')
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {'method': 'GET'}
|
||||||
|
|
||||||
|
def test_logged_in_get_succeeds(self):
|
||||||
|
user = User.objects.create_user('user', 'user@example.com', 'password')
|
||||||
|
request = factory.get('/')
|
||||||
|
del user.is_active
|
||||||
|
request.user = user
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == {'method': 'GET'}
|
||||||
|
|
||||||
|
def test_post_succeeds(self):
|
||||||
|
request = factory.post('/', {'test': 'foo'})
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
expected = {
|
||||||
|
'method': 'POST',
|
||||||
|
'data': {'test': ['foo']}
|
||||||
|
}
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
assert response.data == expected
|
||||||
|
|
||||||
|
def test_options_succeeds(self):
|
||||||
|
request = factory.options('/')
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
|
def test_400_parse_error(self):
|
||||||
|
request = factory.post('/', 'f00bar', content_type='application/json')
|
||||||
|
response = async_to_sync(self.view)(request)
|
||||||
|
expected = {
|
||||||
|
'detail': JSON_ERROR
|
||||||
|
}
|
||||||
|
assert response.status_code == status.HTTP_400_BAD_REQUEST
|
||||||
|
assert sanitise_json_error(response.data) == expected
|
||||||
|
|
||||||
|
|
||||||
class TestCustomExceptionHandler(TestCase):
|
class TestCustomExceptionHandler(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
|
self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER
|
||||||
|
|
Loading…
Reference in New Issue
Block a user