From dd7adc7ed7ccd3100b41b4e5773ffc6d12590074 Mon Sep 17 00:00:00 2001 From: mizvyt Date: Wed, 18 Sep 2019 12:35:59 +0800 Subject: [PATCH] Additional throttle rate configurability --- docs/api-guide/throttling.md | 2 +- rest_framework/throttling.py | 24 ++++++++++++++++--- tests/test_throttling.py | 45 ++++++++++++++++++++++++++++++++++++ 3 files changed, 67 insertions(+), 4 deletions(-) diff --git a/docs/api-guide/throttling.md b/docs/api-guide/throttling.md index 215c735bf..54a0b6ca4 100644 --- a/docs/api-guide/throttling.md +++ b/docs/api-guide/throttling.md @@ -41,7 +41,7 @@ The default throttling policy may be set globally, using the `DEFAULT_THROTTLE_C } } -The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `minute`, `hour` or `day` as the throttle period. +The rate descriptions used in `DEFAULT_THROTTLE_RATES` may include `second`, `minute`, `hour` or `day` as the throttle period. To set the rate to a fraction of a period, simply prepend the desired timespan. For example, a rate of `'100/30s'` will mean "limit requests to a maximum of 100 per every 30 seconds". You can also set the throttling policy on a per-view or per-viewset basis, using the `APIView` class-based views. diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 0ba2ba66b..ca26549b8 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -1,6 +1,7 @@ """ Provides various throttling policies. """ +import re import time from django.core.cache import cache as default_cache @@ -101,9 +102,26 @@ class SimpleRateThrottle(BaseThrottle): """ if rate is None: return (None, None) - num, period = rate.split('/') - num_requests = int(num) - duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] + + try: + num, period = rate.split('/') + num_requests = int(num) + # Get rate multiplier value if available + period_mult, _ = re.split('[s|m|h|d]', period, maxsplit=1) + period_char = re.findall('[s|m|h|d]', period)[0] + except ValueError: + msg = "Incorrect throttle rate set for '%s' scope" % self.scope + raise ImproperlyConfigured(msg) + + try: + period_mult = int(period_mult) + except ValueError: + period_mult = 1 + + duration = {'s': 1 * period_mult, + 'm': 60 * period_mult, + 'h': 3600 * period_mult, + 'd': 86400 * period_mult}[period_char] return (num_requests, duration) def allow_request(self, request, view): diff --git a/tests/test_throttling.py b/tests/test_throttling.py index d5a61232d..d3614fefc 100644 --- a/tests/test_throttling.py +++ b/tests/test_throttling.py @@ -451,6 +451,9 @@ class SimpleRateThrottleTests(TestCase): def setUp(self): SimpleRateThrottle.scope = 'anon' + def tearDown(self): + SimpleRateThrottle.rate = None + def test_get_rate_raises_error_if_scope_is_missing(self): throttle = SimpleRateThrottle() with pytest.raises(ImproperlyConfigured): @@ -462,6 +465,48 @@ class SimpleRateThrottleTests(TestCase): with pytest.raises(ImproperlyConfigured): SimpleRateThrottle() + def test_throttle_raises_error_if_rate_is_incorrect(self): + SimpleRateThrottle.rate = 'rate' + with pytest.raises(ImproperlyConfigured): + SimpleRateThrottle() + + SimpleRateThrottle.rate = 'rate/hour' + with pytest.raises(ImproperlyConfigured): + SimpleRateThrottle() + + SimpleRateThrottle.rate = '100/century' + with pytest.raises(ImproperlyConfigured): + SimpleRateThrottle() + + SimpleRateThrottle.rate = '100/10century' + with pytest.raises(ImproperlyConfigured): + SimpleRateThrottle() + + def test_parse_rate_returns_correct_rate(self): + rate_str = '10/h' + SimpleRateThrottle.rate = rate_str + rate = SimpleRateThrottle().parse_rate(rate_str) + assert rate == (10, 3600) + + rate_str = '30/hour' + SimpleRateThrottle.rate = rate_str + rate = SimpleRateThrottle().parse_rate(rate_str) + assert rate == (30, 3600) + + rate_str = '30/10min' + SimpleRateThrottle.rate = rate_str + rate = SimpleRateThrottle().parse_rate(rate_str) + assert rate == (30, 10 * 60) + + rate_str = '100/30seconds' + SimpleRateThrottle.rate = rate_str + rate = SimpleRateThrottle().parse_rate(rate_str) + assert rate == (100, 30) + + SimpleRateThrottle.rate = '100/10d' + rate = SimpleRateThrottle().parse_rate('100/10d') + assert rate == (100, 10 * 86400) + def test_parse_rate_returns_tuple_with_none_if_rate_not_provided(self): rate = SimpleRateThrottle().parse_rate(None) assert rate == (None, None)