diff --git a/djangorestframework/exceptions.py b/djangorestframework/exceptions.py index 51c5dbb71..0b4dacf73 100644 --- a/djangorestframework/exceptions.py +++ b/djangorestframework/exceptions.py @@ -49,8 +49,14 @@ class UnsupportedMediaType(APIException): class Throttled(APIException): status_code = status.HTTP_429_TOO_MANY_REQUESTS - default_detail = "Request was throttled. Expected available in %d seconds." + default_detail = "Request was throttled." + extra_detail = "Expected available in %d second%s." - def __init__(self, wait, detail=None): + def __init__(self, wait=None, detail=None): import math - self.detail = (detail or self.default_detail) % int(math.ceil(wait)) + self.wait = wait and math.ceil(wait) or None + if wait is not None: + format = detail or self.default_detail + self.extra_detail + self.detail = format % (self.wait, self.wait != 1 and 's' or '') + else: + self.detail = detail or self.default_detail diff --git a/djangorestframework/parsers.py b/djangorestframework/parsers.py index 96dd81ede..fb08c5a0b 100644 --- a/djangorestframework/parsers.py +++ b/djangorestframework/parsers.py @@ -81,7 +81,7 @@ class BaseParser(object): Should return parsed data, or a DataAndFiles object consisting of the parsed data and files. """ - raise NotImplementedError(".parse_stream() Must be overridden to be implemented.") + raise NotImplementedError(".parse_stream() must be overridden.") class JSONParser(BaseParser): diff --git a/djangorestframework/permissions.py b/djangorestframework/permissions.py index bdda4defa..d6405a361 100644 --- a/djangorestframework/permissions.py +++ b/djangorestframework/permissions.py @@ -5,10 +5,6 @@ for checking if a request passes a certain set of constraints. Permission behavior is provided by mixing the :class:`mixins.PermissionsMixin` class into a :class:`View` class. """ -from django.core.cache import cache -from djangorestframework.exceptions import PermissionDenied, Throttled -import time - __all__ = ( 'BasePermission', 'FullAnonAccess', @@ -32,20 +28,11 @@ class BasePermission(object): """ self.view = view - def check_permission(self, auth): + def check_permission(self, request, obj=None): """ Should simply return, or raise an :exc:`response.ImmediateResponse`. """ - pass - - -class FullAnonAccess(BasePermission): - """ - Allows full access. - """ - - def check_permission(self, user): - pass + raise NotImplementedError(".check_permission() must be overridden.") class IsAuthenticated(BasePermission): @@ -53,9 +40,10 @@ class IsAuthenticated(BasePermission): Allows access only to authenticated users. """ - def check_permission(self, user): - if not user.is_authenticated(): - raise PermissionDenied() + def check_permission(self, request, obj=None): + if request.user.is_authenticated(): + return True + return False class IsAdminUser(BasePermission): @@ -63,20 +51,22 @@ class IsAdminUser(BasePermission): Allows access only to admin users. """ - def check_permission(self, user): - if not user.is_staff: - raise PermissionDenied() + def check_permission(self, request, obj=None): + if request.user.is_staff: + return True + return False -class IsUserOrIsAnonReadOnly(BasePermission): +class IsAuthenticatedOrReadOnly(BasePermission): """ The request is authenticated as a user, or is a read-only request. """ - def check_permission(self, user): - if (not user.is_authenticated() and - self.view.method not in SAFE_METHODS): - raise PermissionDenied() + def check_permission(self, request, obj=None): + if (request.user.is_authenticated() or + request.method in SAFE_METHODS): + return True + return False class DjangoModelPermissions(BasePermission): @@ -114,128 +104,10 @@ class DjangoModelPermissions(BasePermission): } return [perm % kwargs for perm in self.perms_map[method]] - def check_permission(self, user): - method = self.view.method - model_cls = self.view.resource.model - perms = self.get_required_permissions(method, model_cls) + def check_permission(self, request, obj=None): + model_cls = self.view.model + perms = self.get_required_permissions(request.method, model_cls) - if not user.is_authenticated or not user.has_perms(perms): - raise PermissionDenied() - - -class BaseThrottle(BasePermission): - """ - Rate throttling of requests. - - The rate (requests / seconds) is set by a :attr:`throttle` attribute - on the :class:`.View` class. The attribute is a string of the form 'number of - requests/period'. - - Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') - - Previous request information used for throttling is stored in the cache. - """ - - attr_name = 'throttle' - default = '0/sec' - timer = time.time - - def get_cache_key(self): - """ - Should return a unique cache-key which can be used for throttling. - Must be overridden. - """ - pass - - def check_permission(self, auth): - """ - Check the throttling. - Return `None` or raise an :exc:`.ImmediateResponse`. - """ - num, period = getattr(self.view, self.attr_name, self.default).split('/') - self.num_requests = int(num) - self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] - self.auth = auth - self.check_throttle() - - def check_throttle(self): - """ - Implement the check to see if the request should be throttled. - - On success calls :meth:`throttle_success`. - On failure calls :meth:`throttle_failure`. - """ - self.key = self.get_cache_key() - self.history = cache.get(self.key, []) - self.now = self.timer() - - # Drop any requests from the history which have now passed the - # throttle duration - while self.history and self.history[-1] <= self.now - self.duration: - self.history.pop() - if len(self.history) >= self.num_requests: - self.throttle_failure() - else: - self.throttle_success() - - def throttle_success(self): - """ - Inserts the current request's timestamp along with the key - into the cache. - """ - self.history.insert(0, self.now) - cache.set(self.key, self.history, self.duration) - header = 'status=SUCCESS; next=%.2f sec' % self.next() - self.view.headers['X-Throttle'] = header - - def throttle_failure(self): - """ - Called when a request to the API has failed due to throttling. - Raises a '503 service unavailable' response. - """ - wait = self.next() - header = 'status=FAILURE; next=%.2f sec' % wait - self.view.headers['X-Throttle'] = header - raise Throttled(wait) - - def next(self): - """ - Returns the recommended next request time in seconds. - """ - if self.history: - remaining_duration = self.duration - (self.now - self.history[-1]) - else: - remaining_duration = self.duration - - available_requests = self.num_requests - len(self.history) + 1 - - return remaining_duration / float(available_requests) - - -class PerUserThrottling(BaseThrottle): - """ - Limits the rate of API calls that may be made by a given user. - - The user id will be used as a unique identifier if the user is - authenticated. For anonymous requests, the IP address of the client will - be used. - """ - - def get_cache_key(self): - if self.auth.is_authenticated(): - ident = self.auth.id - else: - ident = self.view.request.META.get('REMOTE_ADDR', None) - return 'throttle_user_%s' % ident - - -class PerViewThrottling(BaseThrottle): - """ - Limits the rate of API calls that may be used on a given view. - - The class name of the view is used as a unique identifier to - throttle against. - """ - - def get_cache_key(self): - return 'throttle_view_%s' % self.view.__class__.__name__ + if request.user.is_authenticated() and request.user.has_perms(perms, obj): + return True + return False diff --git a/djangorestframework/tests/throttling.py b/djangorestframework/tests/throttling.py index a8e446e85..d144d9568 100644 --- a/djangorestframework/tests/throttling.py +++ b/djangorestframework/tests/throttling.py @@ -8,24 +8,24 @@ from django.core.cache import cache from djangorestframework.compat import RequestFactory from djangorestframework.views import APIView -from djangorestframework.permissions import PerUserThrottling, PerViewThrottling +from djangorestframework.throttling import PerUserThrottling, PerViewThrottling from djangorestframework.response import Response class MockView(APIView): - permission_classes = (PerUserThrottling,) - throttle = '3/sec' + throttle_classes = (PerUserThrottling,) + rate = '3/sec' def get(self, request): return Response('foo') class MockView_PerViewThrottling(MockView): - permission_classes = (PerViewThrottling,) + throttle_classes = (PerViewThrottling,) class MockView_MinuteThrottling(MockView): - throttle = '3/min' + rate = '3/min' class ThrottlingTests(TestCase): @@ -51,7 +51,7 @@ class ThrottlingTests(TestCase): """ Explicitly set the timer, overriding time.time() """ - view.permission_classes[0].timer = lambda self: value + view.throttle_classes[0].timer = lambda self: value def test_request_throttling_expires(self): """ @@ -101,17 +101,20 @@ class ThrottlingTests(TestCase): for timer, expect in expected_headers: self.set_throttle_timer(view, timer) response = view.as_view()(request) - self.assertEquals(response['X-Throttle'], expect) + if expect is not None: + self.assertEquals(response['X-Throttle-Wait-Seconds'], expect) + else: + self.assertFalse('X-Throttle-Wait-Seconds' in response.headers) def test_seconds_fields(self): """ Ensure for second based throttles. """ self.ensure_response_header_contains_proper_throttle_field(MockView, - ((0, 'status=SUCCESS; next=0.33 sec'), - (0, 'status=SUCCESS; next=0.50 sec'), - (0, 'status=SUCCESS; next=1.00 sec'), - (0, 'status=FAILURE; next=1.00 sec') + ((0, None), + (0, None), + (0, None), + (0, '1') )) def test_minutes_fields(self): @@ -119,10 +122,10 @@ class ThrottlingTests(TestCase): Ensure for minute based throttles. """ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, - ((0, 'status=SUCCESS; next=20.00 sec'), - (0, 'status=SUCCESS; next=30.00 sec'), - (0, 'status=SUCCESS; next=60.00 sec'), - (0, 'status=FAILURE; next=60.00 sec') + ((0, None), + (0, None), + (0, None), + (0, '60') )) def test_next_rate_remains_constant_if_followed(self): @@ -131,9 +134,9 @@ class ThrottlingTests(TestCase): the throttling rate should stay constant. """ self.ensure_response_header_contains_proper_throttle_field(MockView_MinuteThrottling, - ((0, 'status=SUCCESS; next=20.00 sec'), - (20, 'status=SUCCESS; next=20.00 sec'), - (40, 'status=SUCCESS; next=20.00 sec'), - (60, 'status=SUCCESS; next=20.00 sec'), - (80, 'status=SUCCESS; next=20.00 sec') + ((0, None), + (20, None), + (40, None), + (60, None), + (80, None) )) diff --git a/djangorestframework/throttling.py b/djangorestframework/throttling.py new file mode 100644 index 000000000..a096eab79 --- /dev/null +++ b/djangorestframework/throttling.py @@ -0,0 +1,139 @@ +from django.core.cache import cache +import time + + +class BaseThrottle(object): + """ + Rate throttling of requests. + """ + + def __init__(self, view=None): + """ + All throttles hold a reference to the instantiating view. + """ + self.view = view + + def check_throttle(self, request): + """ + Return `True` if the request should be allowed, `False` otherwise. + """ + raise NotImplementedError('.check_throttle() must be overridden') + + def wait(self): + """ + Optionally, return a recommeded number of seconds to wait before + the next request. + """ + return None + + +class SimpleCachingThrottle(BaseThrottle): + """ + A simple cache implementation, that only requires `.get_cache_key()` + to be overridden. + + The rate (requests / seconds) is set by a :attr:`throttle` attribute + on the :class:`.View` class. The attribute is a string of the form 'number of + requests/period'. + + Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day') + + Previous request information used for throttling is stored in the cache. + """ + + attr_name = 'rate' + rate = '1000/day' + timer = time.time + + def __init__(self, view): + """ + Check the throttling. + Return `None` or raise an :exc:`.ImmediateResponse`. + """ + super(SimpleCachingThrottle, self).__init__(view) + num, period = getattr(view, self.attr_name, self.rate).split('/') + self.num_requests = int(num) + self.duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + + def get_cache_key(self, request): + """ + Should return a unique cache-key which can be used for throttling. + Must be overridden. + """ + raise NotImplementedError('.get_cache_key() must be overridden') + + def check_throttle(self, request): + """ + Implement the check to see if the request should be throttled. + + On success calls :meth:`throttle_success`. + On failure calls :meth:`throttle_failure`. + """ + self.key = self.get_cache_key(request) + self.history = cache.get(self.key, []) + self.now = self.timer() + + # Drop any requests from the history which have now passed the + # throttle duration + while self.history and self.history[-1] <= self.now - self.duration: + self.history.pop() + if len(self.history) >= self.num_requests: + return self.throttle_failure() + return self.throttle_success() + + def throttle_success(self): + """ + Inserts the current request's timestamp along with the key + into the cache. + """ + self.history.insert(0, self.now) + cache.set(self.key, self.history, self.duration) + return True + + def throttle_failure(self): + """ + Called when a request to the API has failed due to throttling. + """ + return False + + def wait(self): + """ + Returns the recommended next request time in seconds. + """ + if self.history: + remaining_duration = self.duration - (self.now - self.history[-1]) + else: + remaining_duration = self.duration + + available_requests = self.num_requests - len(self.history) + 1 + + return remaining_duration / float(available_requests) + + +class PerUserThrottling(SimpleCachingThrottle): + """ + Limits the rate of API calls that may be made by a given user. + + The user id will be used as a unique identifier if the user is + authenticated. For anonymous requests, the IP address of the client will + be used. + """ + + def get_cache_key(self, request): + if request.user.is_authenticated(): + ident = request.user.id + else: + ident = request.META.get('REMOTE_ADDR', None) + return 'throttle_user_%s' % ident + + +class PerViewThrottling(SimpleCachingThrottle): + """ + Limits the rate of API calls that may be used on a given view. + + The class name of the view is used as a unique identifier to + throttle against. + """ + + def get_cache_key(self, request): + return 'throttle_view_%s' % self.view.__class__.__name__ diff --git a/djangorestframework/views.py b/djangorestframework/views.py index 3f0138d83..9796b3629 100644 --- a/djangorestframework/views.py +++ b/djangorestframework/views.py @@ -18,7 +18,7 @@ from djangorestframework.compat import View as _View, apply_markdown from djangorestframework.response import Response from djangorestframework.request import Request from djangorestframework.settings import api_settings -from djangorestframework import parsers, authentication, permissions, status, exceptions, mixins +from djangorestframework import parsers, authentication, status, exceptions, mixins __all__ = ( @@ -86,7 +86,12 @@ class APIView(_View): List of all authenticating methods to attempt. """ - permission_classes = (permissions.FullAnonAccess,) + throttle_classes = () + """ + List of all throttles to check. + """ + + permission_classes = () """ List of all permissions that must be checked. """ @@ -195,12 +200,27 @@ class APIView(_View): """ return [permission(self) for permission in self.permission_classes] - def check_permissions(self, user): + def get_throttles(self): """ - Check user permissions and either raise an ``ImmediateResponse`` or return. + Instantiates and returns the list of thottles that this view requires. + """ + return [throttle(self) for throttle in self.throttle_classes] + + def check_permissions(self, request, obj=None): + """ + Check user permissions and either raise an ``PermissionDenied`` or return. """ for permission in self.get_permissions(): - permission.check_permission(user) + if not permission.check_permission(request, obj): + raise exceptions.PermissionDenied() + + def check_throttles(self, request): + """ + Check throttles and either raise a `Throttled` exception or return. + """ + for throttle in self.get_throttles(): + if not throttle.check_throttle(request): + raise exceptions.Throttled(throttle.wait()) def initial(self, request, *args, **kargs): """ @@ -232,6 +252,9 @@ class APIView(_View): Handle any exception that occurs, by returning an appropriate response, or re-raising the error. """ + if isinstance(exc, exceptions.Throttled): + self.headers['X-Throttle-Wait-Seconds'] = '%d' % exc.wait + if isinstance(exc, exceptions.APIException): return Response({'detail': exc.detail}, status=exc.status_code) elif isinstance(exc, Http404): @@ -255,8 +278,9 @@ class APIView(_View): try: self.initial(request, *args, **kwargs) - # check that user has the relevant permissions - self.check_permissions(request.user) + # Check that the request is allowed + self.check_permissions(request) + self.check_throttles(request) # Get the appropriate handler method if request.method.lower() in self.http_method_names: @@ -283,11 +307,12 @@ class BaseView(APIView): serializer_class = None def get_serializer(self, data=None, files=None, instance=None): + # TODO: add support for files context = { 'request': self.request, 'format': self.kwargs.get('format', None) } - return self.serializer_class(data, context=context) + return self.serializer_class(data, instance=instance, context=context) class MultipleObjectBaseView(MultipleObjectMixin, BaseView): @@ -301,7 +326,13 @@ class SingleObjectBaseView(SingleObjectMixin, BaseView): """ Base class for generic views onto a model instance. """ - pass + + def get_object(self): + """ + Override default to add support for object-level permissions. + """ + super(self, SingleObjectBaseView).get_object() + self.check_permissions(self.request, self.object) # Concrete view classes that provide method handlers