diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 37e71e93e..c0a9ed6fc 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -2,6 +2,7 @@ Provides various throttling policies. """ import time +import re from django.core.cache import cache as default_cache from django.core.exceptions import ImproperlyConfigured @@ -9,6 +10,7 @@ from django.core.exceptions import ImproperlyConfigured from rest_framework.settings import api_settings + class BaseThrottle: """ Rate throttling of requests. @@ -66,6 +68,7 @@ class SimpleRateThrottle(BaseThrottle): cache_format = 'throttle_%(scope)s_%(ident)s' scope = None THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES + _RATE_DENOMINATOR_REGEX= re.compile("^(\d*)(\D+)$") def __init__(self): if not getattr(self, 'rate', None): @@ -105,10 +108,16 @@ class SimpleRateThrottle(BaseThrottle): return (None, None) num, period = rate.split('/') num_requests = int(num) - denominator_num = period[:-1] - denominator_num = int(denominator_num) if len(denominator_num) > 0 else 1 - duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[-1]] - duration *= denominator_num + + denominator = period[:-1] + m = re.search(self._RATE_DENOMINATOR_REGEX, denominator) + dg = m.groups() + + dn= int(m[0]) if len(m[0])>0 else 1 + du = m[1][0] + + duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[du] + duration *= dn return (num_requests, duration) def allow_request(self, request, view):