Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

""" 

Provides various throttling policies. 

""" 

from __future__ import unicode_literals 

from django.core.cache import cache 

from django.core.exceptions import ImproperlyConfigured 

from rest_framework.settings import api_settings 

import time 

 

 

class BaseThrottle(object):
    """
    Abstract base for request rate throttles.

    Subclasses decide whether a request may proceed (`allow_request`) and
    may optionally suggest a back-off period (`wait`).
    """

    def allow_request(self, request, view):
        """
        Decide whether `request` to `view` may proceed.

        Should return `True` to let the request through and `False` to
        throttle it.  Subclasses must override this; the base implementation
        always raises `NotImplementedError`.
        """
        message = '.allow_request() must be overridden'
        raise NotImplementedError(message)

    def wait(self):
        """
        Suggest how many seconds the client should wait before retrying.

        The base implementation makes no suggestion and returns `None`.
        """
        suggestion = None
        return suggestion

 

 

class SimpleRateThrottle(BaseThrottle):
    """
    A simple cache implementation, that only requires `.get_cache_key()`
    to be overridden.

    The rate (requests / seconds) is taken from a `rate` attribute on the
    throttle class.  When `rate` is not set, it is looked up in
    `THROTTLE_RATES` under the class's `scope`.  The rate is a string of the
    form 'number_of_requests/period'.

    Period should be one of: ('s', 'sec', 'm', 'min', 'h', 'hour', 'd', 'day')
    (only the first character of the period is inspected).

    Previous request information used for throttling is stored in the cache
    as a list of request timestamps, newest first.
    """

    timer = time.time
    cache_format = 'throttle_%(scope)s_%(ident)s'
    scope = None
    THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES

    def __init__(self):
        if not getattr(self, 'rate', None):
            self.rate = self.get_rate()
        self.num_requests, self.duration = self.parse_rate(self.rate)

    def get_cache_key(self, request, view):
        """
        Should return a unique cache-key which can be used for throttling.
        Must be overridden.

        May return `None` if the request should not be throttled.
        """
        raise NotImplementedError('.get_cache_key() must be overridden')

    def get_rate(self):
        """
        Determine the string representation of the allowed request rate.

        Raises `ImproperlyConfigured` if no `scope` is set, or if the scope
        has no entry in `THROTTLE_RATES`.
        """
        if not getattr(self, 'scope', None):
            msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
                   self.__class__.__name__)
            raise ImproperlyConfigured(msg)

        try:
            return self.THROTTLE_RATES[self.scope]
        except KeyError:
            msg = "No default throttle rate set for '%s' scope" % self.scope
            raise ImproperlyConfigured(msg)

    def parse_rate(self, rate):
        """
        Given the request rate string, return a two tuple of:
        <allowed number of requests>, <period of time in seconds>

        A `None` rate yields `(None, None)`, meaning "no throttling".
        """
        if rate is None:
            return (None, None)
        num, period = rate.split('/')
        num_requests = int(num)
        # Only the leading character matters, so 's'/'sec', 'm'/'min', etc.
        duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
        return (num_requests, duration)

    def allow_request(self, request, view):
        """
        Implement the check to see if the request should be throttled.

        Requests are always allowed when no rate is configured, or when
        `get_cache_key` returns `None` (its documented way of opting a
        request out of throttling).

        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        if self.rate is None:
            return True

        self.key = self.get_cache_key(request, view)
        if self.key is None:
            # Never read or write a request history under a `None` key.
            return True

        self.history = cache.get(self.key, [])
        self.now = self.timer()

        # History is kept newest first (see `throttle_success`), so expired
        # timestamps accumulate at the tail of the list.
        while self.history and self.history[-1] <= self.now - self.duration:
            self.history.pop()
        if len(self.history) >= self.num_requests:
            return self.throttle_failure()
        return self.throttle_success()

    def throttle_success(self):
        """
        Inserts the current request's timestamp along with the key
        into the cache.
        """
        self.history.insert(0, self.now)
        cache.set(self.key, self.history, self.duration)
        return True

    def throttle_failure(self):
        """
        Called when a request to the API has failed due to throttling.
        """
        return False

    def wait(self):
        """
        Returns the recommended next request time in seconds.

        Returns `None` when no sensible recommendation can be made, i.e.
        when more requests are recorded than the current rate allows (for
        example after the rate was lowered).
        """
        if self.history:
            remaining_duration = self.duration - (self.now - self.history[-1])
        else:
            remaining_duration = self.duration

        available_requests = self.num_requests - len(self.history) + 1
        if available_requests <= 0:
            # Dividing here would raise ZeroDivisionError (== 0) or produce
            # a negative wait (< 0).
            return None

        return remaining_duration / float(available_requests)

 

 

class AnonRateThrottle(SimpleRateThrottle):
    """
    Throttles API calls made by unauthenticated (anonymous) users.

    Authenticated requests get no cache key at all.  Anonymous requests are
    keyed on the client's `REMOTE_ADDR`.
    """
    scope = 'anon'

    def get_cache_key(self, request, view):
        # Per the `get_cache_key` contract, `None` means "do not throttle";
        # this throttle only applies to anonymous users.
        if request.user.is_authenticated():
            return None

        client_address = request.META.get('REMOTE_ADDR')
        key_parts = {'ident': client_address, 'scope': self.scope}
        return self.cache_format % key_parts

 

 

class UserRateThrottle(SimpleRateThrottle):
    """
    Throttles API calls on a per-user basis.

    Authenticated requests are keyed on the user's id.  Anonymous requests
    fall back to the client's `REMOTE_ADDR`.
    """
    scope = 'user'

    def get_cache_key(self, request, view):
        user = request.user
        ident = (user.id if user.is_authenticated()
                 else request.META.get('REMOTE_ADDR'))
        key_parts = {'ident': ident, 'scope': self.scope}
        return self.cache_format % key_parts

 

 

class ScopedRateThrottle(SimpleRateThrottle):
    """
    Limits the rate of API calls by different amounts for various parts of
    the API.

    Only views that define a (truthy) `throttle_scope` attribute are
    throttled.  The rate for that scope is looked up in `THROTTLE_RATES`.
    The cache key combines the view's scope with the requester's identity:
    the user id for authenticated requests, otherwise the client's
    `REMOTE_ADDR`.
    """
    scope_attr = 'throttle_scope'

    def __init__(self):
        # Deliberately skip SimpleRateThrottle.__init__: the scope (and so
        # the rate) is only known once the view calls `allow_request`.
        pass

    def allow_request(self, request, view):
        """
        Resolve the scope and rate from `view`, then run the standard check.

        Views without a `throttle_scope` are never throttled.
        """
        view_scope = getattr(view, self.scope_attr, None)
        self.scope = view_scope
        if not view_scope:
            return True

        # The same rate resolution SimpleRateThrottle.__init__ would have
        # done, now that the scope is known.
        configured_rate = self.get_rate()
        self.rate = configured_rate
        self.num_requests, self.duration = self.parse_rate(configured_rate)

        return super(ScopedRateThrottle, self).allow_request(request, view)

    def get_cache_key(self, request, view):
        """
        Build the cache key from `self.scope` (set by `allow_request`) and
        the requester's identity.

        The identity is the user id when the request is authenticated, and
        otherwise the client's `REMOTE_ADDR`.
        """
        if not request.user.is_authenticated():
            ident = request.META.get('REMOTE_ADDR')
        else:
            ident = request.user.id

        return self.cache_format % dict(
            [('scope', self.scope), ('ident', ident)]
        )