mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-10-24 20:51:19 +03:00
Refactored throttling
This commit is contained in:
parent
8457c87196
commit
c28b719333
|
@ -49,8 +49,14 @@ class UnsupportedMediaType(APIException):
|
||||||
|
|
||||||
class Throttled(APIException):
|
class Throttled(APIException):
|
||||||
status_code = status.HTTP_429_TOO_MANY_REQUESTS
|
status_code = status.HTTP_429_TOO_MANY_REQUESTS
|
||||||
default_detail = "Request was throttled. Expected available in %d seconds."
|
default_detail = "Request was throttled."
|
||||||
|
extra_detail = "Expected available in %d second%s."
|
||||||
|
|
||||||
def __init__(self, wait, detail=None):
|
def __init__(self, wait=None, detail=None):
|
||||||
import math
|
import math
|
||||||
self.detail = (detail or self.default_detail) % int(math.ceil(wait))
|
self.wait = wait and math.ceil(wait) or None
|
||||||
|
if wait is not None:
|
||||||
|
format = detail or self.default_detail + self.extra_detail
|
||||||
|
self.detail = format % (self.wait, self.wait != 1 and 's' or '')
|
||||||
|
else:
|
||||||
|
self.detail = detail or self.default_detail
|
||||||
|
|
|
@ -81,7 +81,7 @@ class BaseParser(object):
|
||||||
Should return parsed data, or a DataAndFiles object consisting of the
|
Should return parsed data, or a DataAndFiles object consisting of the
|
||||||
parsed data and files.
|
parsed data and files.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError(".parse_stream() Must be overridden to be implemented.")
|
raise NotImplementedError(".parse_stream() must be overridden.")
|
||||||
|
|
||||||
|
|
||||||
class JSONParser(BaseParser):
|
class JSONParser(BaseParser):
|
||||||
|
|
|
@ -5,10 +5,6 @@ for checking if a request passes a certain set of constraints.
|
||||||
Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class.
|
Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from django.core.cache import cache
|
|
||||||
from djangorestframework.exceptions import PermissionDenied, Throttled
|
|
||||||
import time
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'BasePermission',
|
'BasePermission',
|
||||||
'FullAnonAccess',
|
'FullAnonAccess',
|
||||||
|
@ -32,20 +28,11 @@ class BasePermission(object):
|
||||||
"""
|
"""
|
||||||
self.view = view
|
self.view = view
|
||||||
|
|
||||||
def check_permission(self, auth):
|
def check_permission(self, request, obj=None):
|
||||||
"""
|
"""
|
||||||
Should simply return, or raise an :exc:`response.ImmediateResponse`.
|
Should simply return, or raise an :exc:`response.ImmediateResponse`.
|
||||||
"""
|
"""
|
||||||
pass
|
raise NotImplementedError(".check_permission() must be overridden.")
|
||||||
|
|
||||||
|
|
||||||
class FullAnonAccess(BasePermission):
|
|
||||||
"""
|
|
||||||
Allows full access.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def check_permission(self, user):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class IsAuthenticated(BasePermission):
|
class IsAuthenticated(BasePermission):
|
||||||
|
@ -53,9 +40,10 @@ class IsAuthenticated(BasePermission):
|
||||||
Allows access only to authenticated users.
|
Allows access only to authenticated users.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def check_permission(self, user):
|
def check_permission(self, request, obj=None):
|
||||||
if not user.is_authenticated():
|
if request.user.is_authenticated():
|
||||||
raise PermissionDenied()
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class IsAdminUser(BasePermission):
|
class IsAdminUser(BasePermission):
|
||||||
|
@ -63,20 +51,22 @@ class IsAdminUser(BasePermission):
|
||||||
Allows access only to admin users.
|
Allows access only to admin users.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def check_permission(self, user):
|
def check_permission(self, request, obj=None):
|
||||||
if not user.is_staff:
|
if request.user.is_staff:
|
||||||
raise PermissionDenied()
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class IsUserOrIsAnonReadOnly(BasePermission):
|
class IsAuthenticatedOrReadOnly(BasePermission):
|
||||||
"""
|
"""
|
||||||
The request is authenticated as a user, or is a read-only request.
|
The request is authenticated as a user, or is a read-only request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def check_permission(self, user):
|
def check_permission(self, request, obj=None):
|
||||||
if (not user.is_authenticated() and
|
if (request.user.is_authenticated() or
|
||||||
self.view.method not in SAFE_METHODS):
|
request.method in SAFE_METHODS):
|
||||||
raise PermissionDenied()
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class DjangoModelPermissions(BasePermission):
|
class DjangoModelPermissions(BasePermission):
|
||||||
|
@ -114,128 +104,10 @@ class DjangoModelPermissions(BasePermission):
|
||||||
}
|
}
|
||||||
return [perm % kwargs for perm in self.perms_map[method]]
|
return [perm % kwargs for perm in self.perms_map[method]]
|
||||||
|
|
||||||
def check_permission(self, user):
|
def check_permission(self, request, obj=None):
|
||||||
method = self.view.method
|
model_cls = self.view.model
|
||||||
model_cls = self.view.resource.model
|
perms = self.get_required_permissions(request.method, model_cls)
|
||||||
perms = self.get_required_permissions(method, model_cls)
|
|
||||||
|
|
||||||
if not user.is_authenticated or not user.has_perms(perms):
|
if request.user.is_authenticated() and request.user.has_perms(perms, obj):
|
||||||
raise PermissionDenied()
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
class BaseThrottle(BasePermission):
|
|
||||||
"""
|
|
||||||
Rate throttling of requests.
|
|
||||||
|
|
||||||
The rate (requests / seconds) is set by a :attr:`throttle` attribute
|
|
||||||
on the :class:`.View` class. The attribute is a string of the form 'number of
|
|
||||||
requests/period'.
|
|
||||||
|
|
||||||
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
|
|
||||||
|
|
||||||
Previous request information used for throttling is stored in the cache.
|
|
||||||
"""
|
|
||||||
|
|
||||||
attr_name = 'throttle'
|
|
||||||
default = '0/sec'
|
|
||||||
timer = time.time
|
|
||||||
|
|
||||||
def get_cache_key(self):
|
|
||||||
"""
|
|
||||||
Should return a unique cache-key which can be used for throttling.
|
|
||||||
Must be overridden.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def check_permission(self, auth):
|
|
||||||
"""
|
|
||||||
Check the throttling.
|
|
||||||
Return `None` or raise an :exc:`.ImmediateResponse`.
|
|
||||||
"""
|
|
||||||
num, period = getattr(self.view, self.attr_name, self.default).split('/')
|
|
||||||
self.num_requests = int(num)
|
|
||||||
self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
|
|
||||||
self.auth = auth
|
|
||||||
self.check_throttle()
|
|
||||||
|
|
||||||
def check_throttle(self):
|
|
||||||
"""
|
|
||||||
Implement the check to see if the request should be throttled.
|
|
||||||
|
|
||||||
On success calls :meth:`throttle_success`.
|
|
||||||
On failure calls :meth:`throttle_failure`.
|
|
||||||
"""
|
|
||||||
self.key = self.get_cache_key()
|
|
||||||
self.history = cache.get(self.key, [])
|
|
||||||
self.now = self.timer()
|
|
||||||
|
|
||||||
# Drop any requests from the history which have now passed the
|
|
||||||
# throttle duration
|
|
||||||
while self.history and self.history[-1] <= self.now - self.duration:
|
|
||||||
self.history.pop()
|
|
||||||
if len(self.history) >= self.num_requests:
|
|
||||||
self.throttle_failure()
|
|
||||||
else:
|
|
||||||
self.throttle_success()
|
|
||||||
|
|
||||||
def throttle_success(self):
|
|
||||||
"""
|
|
||||||
Inserts the current request's timestamp along with the key
|
|
||||||
into the cache.
|
|
||||||
"""
|
|
||||||
self.history.insert(0, self.now)
|
|
||||||
cache.set(self.key, self.history, self.duration)
|
|
||||||
header = 'status=SUCCESS; next=%.2f sec' % self.next()
|
|
||||||
self.view.headers['X-Throttle'] = header
|
|
||||||
|
|
||||||
def throttle_failure(self):
|
|
||||||
"""
|
|
||||||
Called when a request to the API has failed due to throttling.
|
|
||||||
Raises a '503 service unavailable' response.
|
|
||||||
"""
|
|
||||||
wait = self.next()
|
|
||||||
header = 'status=FAILURE; next=%.2f sec' % wait
|
|
||||||
self.view.headers['X-Throttle'] = header
|
|
||||||
raise Throttled(wait)
|
|
||||||
|
|
||||||
def next(self):
|
|
||||||
"""
|
|
||||||
Returns the recommended next request time in seconds.
|
|
||||||
"""
|
|
||||||
if self.history:
|
|
||||||
remaining_duration = self.duration - (self.now - self.history[-1])
|
|
||||||
else:
|
|
||||||
remaining_duration = self.duration
|
|
||||||
|
|
||||||
available_requests = self.num_requests - len(self.history) + 1
|
|
||||||
|
|
||||||
return remaining_duration / float(available_requests)
|
|
||||||
|
|
||||||
|
|
||||||
class PerUserThrottling(BaseThrottle):
|
|
||||||
"""
|
|
||||||
Limits the rate of API calls that may be made by a given user.
|
|
||||||
|
|
||||||
The user id will be used as a unique identifier if the user is
|
|
||||||
authenticated. For anonymous requests, the IP address of the client will
|
|
||||||
be used.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_cache_key(self):
|
|
||||||
if self.auth.is_authenticated():
|
|
||||||
ident = self.auth.id
|
|
||||||
else:
|
|
||||||
ident = self.view.request.META.get('REMOTE_ADDR', None)
|
|
||||||
return 'throttle_user_%s' % ident
|
|
||||||
|
|
||||||
|
|
||||||
class PerViewThrottling(BaseThrottle):
|
|
||||||
"""
|
|
||||||
Limits the rate of API calls that may be used on a given view.
|
|
||||||
|
|
||||||
The class name of the view is used as a unique identifier to
|
|
||||||
throttle against.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_cache_key(self):
|
|
||||||
return 'throttle_view_%s' % self.view.__class__.__name__
|
|
||||||
|
|
|
@ -8,24 +8,24 @@ from django.core.cache import cache
|
||||||
|
|
||||||
from djangorestframework.compat import RequestFactory
|
from djangorestframework.compat import RequestFactory
|
||||||
from djangorestframework.views import APIView
|
from djangorestframework.views import APIView
|
||||||
from djangorestframework.permissions import PerUserThrottling, PerViewThrottling
|
from djangorestframework.throttling import PerUserThrottling, PerViewThrottling
|
||||||
from djangorestframework.response import Response
|
from djangorestframework.response import Response
|
||||||
|
|
||||||
|
|
||||||
class MockView(APIView):
|
class MockView(APIView):
|
||||||
permission_classes = (PerUserThrottling,)
|
throttle_classes = (PerUserThrottling,)
|
||||||
throttle = '3/sec'
|
rate = '3/sec'
|
||||||
|
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
return Response('foo')
|
return Response('foo')
|
||||||
|
|
||||||
|
|
||||||
class MockView_PerViewThrottling(MockView):
|
class MockView_PerViewThrottling(MockView):
|
||||||
permission_classes = (PerViewThrottling,)
|
throttle_classes = (PerViewThrottling,)
|
||||||
|
|
||||||
|
|
||||||
class MockView_MinuteThrottling(MockView):
|
class MockView_MinuteThrottling(MockView):
|
||||||
throttle = '3/min'
|
rate = '3/min'
|
||||||
|
|
||||||
|
|
||||||
class ThrottlingTests(TestCase):
|
class ThrottlingTests(TestCase):
|
||||||
|
@ -51,7 +51,7 @@ class ThrottlingTests(TestCase):
|
||||||
"""
|
"""
|
||||||
Explicitly set the timer, overriding time.time()
|
Explicitly set the timer, overriding time.time()
|
||||||
"""
|
"""
|
||||||
view.permission_classes[0].timer = lambda self: value
|
view.throttle_classes[0].timer = lambda self: value
|
||||||
|
|
||||||
def test_request_throttling_expires(self):
|
def test_request_throttling_expires(self):
|
||||||
"""
|
"""
|
||||||
|
@ -101,17 +101,20 @@ class ThrottlingTests(TestCase):
|
||||||
for timer, expect in expected_headers:
|
for timer, expect in expected_headers:
|
||||||
self.set_throttle_timer(view, timer)
|
self.set_throttle_timer(view, timer)
|
||||||
response = view.as_view()(request)
|
response = view.as_view()(request)
|
||||||
self.assertEquals(response['X-Throttle'], expect)
|
if expect is not None:
|
||||||
|
self.assertEquals(response['X-Throttle-Wait-Seconds'], expect)
|
||||||
|
else:
|
||||||
|
self.assertFalse('X-Throttle-Wait-Seconds' in response.headers)
|
||||||
|
|
||||||
def test_seconds_fields(self):
|
def test_seconds_fields(self):
|
||||||
"""
|
"""
|
||||||
Ensure for second based throttles.
|
Ensure for second based throttles.
|
||||||
"""
|
"""
|
||||||
self.ensure_response_header_contains_proper_throttle_field(MockView,
|
self.ensure_response_header_contains_proper_throttle_field(MockView,
|
||||||
((0, 'status=SUCCESS; next=0.33 sec'),
|
((0, None),
|
||||||
(0, 'status=SUCCESS; next=0.50 sec'),
|
(0, None),
|
||||||
(0, 'status=SUCCESS; next=1.00 sec'),
|
(0, None),
|
||||||
(0, 'status=FAILURE; next=1.00 sec')
|
(0, '1')
|
||||||
))
|
))
|
||||||
|
|
||||||
def test_minutes_fields(self):
|
def test_minutes_fields(self):
|
||||||
|
@ -119,10 +122,10 @@ class ThrottlingTests(TestCase):
|
||||||
Ensure for minute based throttles.
|
Ensure for minute based throttles.
|
||||||
"""
|
"""
|
||||||
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
|
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
|
||||||
((0, 'status=SUCCESS; next=20.00 sec'),
|
((0, None),
|
||||||
(0, 'status=SUCCESS; next=30.00 sec'),
|
(0, None),
|
||||||
(0, 'status=SUCCESS; next=60.00 sec'),
|
(0, None),
|
||||||
(0, 'status=FAILURE; next=60.00 sec')
|
(0, '60')
|
||||||
))
|
))
|
||||||
|
|
||||||
def test_next_rate_remains_constant_if_followed(self):
|
def test_next_rate_remains_constant_if_followed(self):
|
||||||
|
@ -131,9 +134,9 @@ class ThrottlingTests(TestCase):
|
||||||
the throttling rate should stay constant.
|
the throttling rate should stay constant.
|
||||||
"""
|
"""
|
||||||
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
|
self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling,
|
||||||
((0, 'status=SUCCESS; next=20.00 sec'),
|
((0, None),
|
||||||
(20, 'status=SUCCESS; next=20.00 sec'),
|
(20, None),
|
||||||
(40, 'status=SUCCESS; next=20.00 sec'),
|
(40, None),
|
||||||
(60, 'status=SUCCESS; next=20.00 sec'),
|
(60, None),
|
||||||
(80, 'status=SUCCESS; next=20.00 sec')
|
(80, None)
|
||||||
))
|
))
|
||||||
|
|
139
djangorestframework/throttling.py
Normal file
139
djangorestframework/throttling.py
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
from django.core.cache import cache
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class BaseThrottle(object):
|
||||||
|
"""
|
||||||
|
Rate throttling of requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, view=None):
|
||||||
|
"""
|
||||||
|
All throttles hold a reference to the instantiating view.
|
||||||
|
"""
|
||||||
|
self.view = view
|
||||||
|
|
||||||
|
def check_throttle(self, request):
|
||||||
|
"""
|
||||||
|
Return `True` if the request should be allowed, `False` otherwise.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('.check_throttle() must be overridden')
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
"""
|
||||||
|
Optionally, return a recommeded number of seconds to wait before
|
||||||
|
the next request.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleCachingThrottle(BaseThrottle):
|
||||||
|
"""
|
||||||
|
A simple cache implementation, that only requires `.get_cache_key()`
|
||||||
|
to be overridden.
|
||||||
|
|
||||||
|
The rate (requests / seconds) is set by a :attr:`throttle` attribute
|
||||||
|
on the :class:`.View` class. The attribute is a string of the form 'number of
|
||||||
|
requests/period'.
|
||||||
|
|
||||||
|
Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
|
||||||
|
|
||||||
|
Previous request information used for throttling is stored in the cache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
attr_name = 'rate'
|
||||||
|
rate = '1000/day'
|
||||||
|
timer = time.time
|
||||||
|
|
||||||
|
def __init__(self, view):
|
||||||
|
"""
|
||||||
|
Check the throttling.
|
||||||
|
Return `None` or raise an :exc:`.ImmediateResponse`.
|
||||||
|
"""
|
||||||
|
super(SimpleCachingThrottle, self).__init__(view)
|
||||||
|
num, period = getattr(view, self.attr_name, self.rate).split('/')
|
||||||
|
self.num_requests = int(num)
|
||||||
|
self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
|
||||||
|
|
||||||
|
def get_cache_key(self, request):
|
||||||
|
"""
|
||||||
|
Should return a unique cache-key which can be used for throttling.
|
||||||
|
Must be overridden.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError('.get_cache_key() must be overridden')
|
||||||
|
|
||||||
|
def check_throttle(self, request):
|
||||||
|
"""
|
||||||
|
Implement the check to see if the request should be throttled.
|
||||||
|
|
||||||
|
On success calls :meth:`throttle_success`.
|
||||||
|
On failure calls :meth:`throttle_failure`.
|
||||||
|
"""
|
||||||
|
self.key = self.get_cache_key(request)
|
||||||
|
self.history = cache.get(self.key, [])
|
||||||
|
self.now = self.timer()
|
||||||
|
|
||||||
|
# Drop any requests from the history which have now passed the
|
||||||
|
# throttle duration
|
||||||
|
while self.history and self.history[-1] <= self.now - self.duration:
|
||||||
|
self.history.pop()
|
||||||
|
if len(self.history) >= self.num_requests:
|
||||||
|
return self.throttle_failure()
|
||||||
|
return self.throttle_success()
|
||||||
|
|
||||||
|
def throttle_success(self):
|
||||||
|
"""
|
||||||
|
Inserts the current request's timestamp along with the key
|
||||||
|
into the cache.
|
||||||
|
"""
|
||||||
|
self.history.insert(0, self.now)
|
||||||
|
cache.set(self.key, self.history, self.duration)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def throttle_failure(self):
|
||||||
|
"""
|
||||||
|
Called when a request to the API has failed due to throttling.
|
||||||
|
"""
|
||||||
|
return False
|
||||||
|
|
||||||
|
def wait(self):
|
||||||
|
"""
|
||||||
|
Returns the recommended next request time in seconds.
|
||||||
|
"""
|
||||||
|
if self.history:
|
||||||
|
remaining_duration = self.duration - (self.now - self.history[-1])
|
||||||
|
else:
|
||||||
|
remaining_duration = self.duration
|
||||||
|
|
||||||
|
available_requests = self.num_requests - len(self.history) + 1
|
||||||
|
|
||||||
|
return remaining_duration / float(available_requests)
|
||||||
|
|
||||||
|
|
||||||
|
class PerUserThrottling(SimpleCachingThrottle):
|
||||||
|
"""
|
||||||
|
Limits the rate of API calls that may be made by a given user.
|
||||||
|
|
||||||
|
The user id will be used as a unique identifier if the user is
|
||||||
|
authenticated. For anonymous requests, the IP address of the client will
|
||||||
|
be used.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_cache_key(self, request):
|
||||||
|
if request.user.is_authenticated():
|
||||||
|
ident = request.user.id
|
||||||
|
else:
|
||||||
|
ident = request.META.get('REMOTE_ADDR', None)
|
||||||
|
return 'throttle_user_%s' % ident
|
||||||
|
|
||||||
|
|
||||||
|
class PerViewThrottling(SimpleCachingThrottle):
|
||||||
|
"""
|
||||||
|
Limits the rate of API calls that may be used on a given view.
|
||||||
|
|
||||||
|
The class name of the view is used as a unique identifier to
|
||||||
|
throttle against.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_cache_key(self, request):
|
||||||
|
return 'throttle_view_%s' % self.view.__class__.__name__
|
|
@ -18,7 +18,7 @@ from djangorestframework.compat import View as _View, apply_markdown
|
||||||
from djangorestframework.response import Response
|
from djangorestframework.response import Response
|
||||||
from djangorestframework.request import Request
|
from djangorestframework.request import Request
|
||||||
from djangorestframework.settings import api_settings
|
from djangorestframework.settings import api_settings
|
||||||
from djangorestframework import parsers, authentication, permissions, status, exceptions, mixins
|
from djangorestframework import parsers, authentication, status, exceptions, mixins
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
|
@ -86,7 +86,12 @@ class APIView(_View):
|
||||||
List of all authenticating methods to attempt.
|
List of all authenticating methods to attempt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
permission_classes = (permissions.FullAnonAccess,)
|
throttle_classes = ()
|
||||||
|
"""
|
||||||
|
List of all throttles to check.
|
||||||
|
"""
|
||||||
|
|
||||||
|
permission_classes = ()
|
||||||
"""
|
"""
|
||||||
List of all permissions that must be checked.
|
List of all permissions that must be checked.
|
||||||
"""
|
"""
|
||||||
|
@ -195,12 +200,27 @@ class APIView(_View):
|
||||||
"""
|
"""
|
||||||
return [permission(self) for permission in self.permission_classes]
|
return [permission(self) for permission in self.permission_classes]
|
||||||
|
|
||||||
def check_permissions(self, user):
|
def get_throttles(self):
|
||||||
"""
|
"""
|
||||||
Check user permissions and either raise an ``ImmediateResponse`` or return.
|
Instantiates and returns the list of thottles that this view requires.
|
||||||
|
"""
|
||||||
|
return [throttle(self) for throttle in self.throttle_classes]
|
||||||
|
|
||||||
|
def check_permissions(self, request, obj=None):
|
||||||
|
"""
|
||||||
|
Check user permissions and either raise an ``PermissionDenied`` or return.
|
||||||
"""
|
"""
|
||||||
for permission in self.get_permissions():
|
for permission in self.get_permissions():
|
||||||
permission.check_permission(user)
|
if not permission.check_permission(request, obj):
|
||||||
|
raise exceptions.PermissionDenied()
|
||||||
|
|
||||||
|
def check_throttles(self, request):
|
||||||
|
"""
|
||||||
|
Check throttles and either raise a `Throttled` exception or return.
|
||||||
|
"""
|
||||||
|
for throttle in self.get_throttles():
|
||||||
|
if not throttle.check_throttle(request):
|
||||||
|
raise exceptions.Throttled(throttle.wait())
|
||||||
|
|
||||||
def initial(self, request, *args, **kargs):
|
def initial(self, request, *args, **kargs):
|
||||||
"""
|
"""
|
||||||
|
@ -232,6 +252,9 @@ class APIView(_View):
|
||||||
Handle any exception that occurs, by returning an appropriate response,
|
Handle any exception that occurs, by returning an appropriate response,
|
||||||
or re-raising the error.
|
or re-raising the error.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(exc, exceptions.Throttled):
|
||||||
|
self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait
|
||||||
|
|
||||||
if isinstance(exc, exceptions.APIException):
|
if isinstance(exc, exceptions.APIException):
|
||||||
return Response({'detail': exc.detail}, status=exc.status_code)
|
return Response({'detail': exc.detail}, status=exc.status_code)
|
||||||
elif isinstance(exc, Http404):
|
elif isinstance(exc, Http404):
|
||||||
|
@ -255,8 +278,9 @@ class APIView(_View):
|
||||||
try:
|
try:
|
||||||
self.initial(request, *args, **kwargs)
|
self.initial(request, *args, **kwargs)
|
||||||
|
|
||||||
# check that user has the relevant permissions
|
# Check that the request is allowed
|
||||||
self.check_permissions(request.user)
|
self.check_permissions(request)
|
||||||
|
self.check_throttles(request)
|
||||||
|
|
||||||
# Get the appropriate handler method
|
# Get the appropriate handler method
|
||||||
if request.method.lower() in self.http_method_names:
|
if request.method.lower() in self.http_method_names:
|
||||||
|
@ -283,11 +307,12 @@ class BaseView(APIView):
|
||||||
serializer_class = None
|
serializer_class = None
|
||||||
|
|
||||||
def get_serializer(self, data=None, files=None, instance=None):
|
def get_serializer(self, data=None, files=None, instance=None):
|
||||||
|
# TODO: add support for files
|
||||||
context = {
|
context = {
|
||||||
'request': self.request,
|
'request': self.request,
|
||||||
'format': self.kwargs.get('format', None)
|
'format': self.kwargs.get('format', None)
|
||||||
}
|
}
|
||||||
return self.serializer_class(data, context=context)
|
return self.serializer_class(data, instance=instance, context=context)
|
||||||
|
|
||||||
|
|
||||||
class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
|
class MultipleObjectBaseView(MultipleObjectMixin, BaseView):
|
||||||
|
@ -301,7 +326,13 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView):
|
||||||
"""
|
"""
|
||||||
Base class for generic views onto a model instance.
|
Base class for generic views onto a model instance.
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
def get_object(self):
|
||||||
|
"""
|
||||||
|
Override default to add support for object-level permissions.
|
||||||
|
"""
|
||||||
|
super(self, SingleObjectBaseView).get_object()
|
||||||
|
self.check_permissions(self.request, self.object)
|
||||||
|
|
||||||
|
|
||||||
# Concrete view classes that provide method handlers
|
# Concrete view classes that provide method handlers
|
||||||
|
|
Loading…
Reference in New Issue
Block a user