Permissions and throttles no longer have a view attribute on self. Explicitly passed to .has_permissions(request, view, obj=None) / .allow_request(request, view)

This commit is contained in:
Tom Christie 2012-10-10 10:02:37 +01:00
parent 900c4b625b
commit ccd2b0117d
6 changed files with 28 additions and 46 deletions

View File

@ -88,7 +88,7 @@ The `DjangoModelPermissions` class also supports object-level permissions. Thir
## Custom permissions ## Custom permissions
To implement a custom permission, override `BasePermission` and implement the `.has_permission(self, request, obj=None)` method. To implement a custom permission, override `BasePermission` and implement the `.has_permission(self, request, view, obj=None)` method.
The method should return `True` if the request should be granted access, and `False` otherwise. The method should return `True` if the request should be granted access, and `False` otherwise.

View File

@ -144,8 +144,8 @@ User requests to either `ContactListView` or `ContactDetailView` would be restri
## Custom throttles ## Custom throttles
To create a custom throttle, override `BaseThrottle` and implement `.allow_request(request)`. The method should return `True` if the request should be allowed, and `False` otherwise. To create a custom throttle, override `BaseThrottle` and implement `.allow_request(request, view)`. The method should return `True` if the request should be allowed, and `False` otherwise.
Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recomended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.check_throttle()` has previously returned `False`. Optionally you may also override the `.wait()` method. If implemented, `.wait()` should return a recomended number of seconds to wait before attempting the next request, or `None`. The `.wait()` method will only be called if `.allow_request()` has previously returned `False`.
[permissions]: permissions.md [permissions]: permissions.md

View File

@ -13,13 +13,8 @@ class BasePermission(object):
""" """
A base class from which all permission classes should inherit. A base class from which all permission classes should inherit.
""" """
def __init__(self, view):
"""
Permission classes are always passed the current view on creation.
"""
self.view = view
def has_permission(self, request, obj=None): def has_permission(self, request, view, obj=None):
""" """
Should simply return, or raise an :exc:`response.ImmediateResponse`. Should simply return, or raise an :exc:`response.ImmediateResponse`.
""" """
@ -31,7 +26,7 @@ class IsAuthenticated(BasePermission):
Allows access only to authenticated users. Allows access only to authenticated users.
""" """
def has_permission(self, request, obj=None): def has_permission(self, request, view, obj=None):
if request.user and request.user.is_authenticated(): if request.user and request.user.is_authenticated():
return True return True
return False return False
@ -42,7 +37,7 @@ class IsAdminUser(BasePermission):
Allows access only to admin users. Allows access only to admin users.
""" """
def has_permission(self, request, obj=None): def has_permission(self, request, view, obj=None):
if request.user and request.user.is_staff: if request.user and request.user.is_staff:
return True return True
return False return False
@ -53,7 +48,7 @@ class IsAuthenticatedOrReadOnly(BasePermission):
The request is authenticated as a user, or is a read-only request. The request is authenticated as a user, or is a read-only request.
""" """
def has_permission(self, request, obj=None): def has_permission(self, request, view, obj=None):
if (request.method in SAFE_METHODS or if (request.method in SAFE_METHODS or
request.user and request.user and
request.user.is_authenticated()): request.user.is_authenticated()):
@ -96,8 +91,8 @@ class DjangoModelPermissions(BasePermission):
} }
return [perm % kwargs for perm in self.perms_map[method]] return [perm % kwargs for perm in self.perms_map[method]]
def has_permission(self, request, obj=None): def has_permission(self, request, view, obj=None):
model_cls = self.view.model model_cls = view.model
perms = self.get_required_permissions(request.method, model_cls) perms = self.get_required_permissions(request.method, model_cls)
if (request.user and if (request.user and

View File

@ -92,7 +92,7 @@ urlpatterns = patterns('',
class POSTDeniedPermission(permissions.BasePermission): class POSTDeniedPermission(permissions.BasePermission):
def has_permission(self, request, obj=None): def has_permission(self, request, view, obj=None):
return request.method != 'POST' return request.method != 'POST'

View File

@ -8,14 +8,7 @@ class BaseThrottle(object):
""" """
Rate throttling of requests. Rate throttling of requests.
""" """
def allow_request(self, request, view):
def __init__(self, view=None):
"""
All throttles hold a reference to the instantiating view.
"""
self.view = view
def allow_request(self, request):
""" """
Return `True` if the request should be allowed, `False` otherwise. Return `True` if the request should be allowed, `False` otherwise.
""" """
@ -48,13 +41,12 @@ class SimpleRateThottle(BaseThrottle):
cache_format = 'throtte_%(scope)s_%(ident)s' cache_format = 'throtte_%(scope)s_%(ident)s'
scope = None scope = None
def __init__(self, view): def __init__(self):
super(SimpleRateThottle, self).__init__(view)
if not getattr(self, 'rate', None): if not getattr(self, 'rate', None):
self.rate = self.get_rate() self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate) self.num_requests, self.duration = self.parse_rate(self.rate)
def get_cache_key(self, request): def get_cache_key(self, request, view):
""" """
Should return a unique cache-key which can be used for throttling. Should return a unique cache-key which can be used for throttling.
Must be overridden. Must be overridden.
@ -90,7 +82,7 @@ class SimpleRateThottle(BaseThrottle):
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]] duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
return (num_requests, duration) return (num_requests, duration)
def allow_request(self, request): def allow_request(self, request, view):
""" """
Implement the check to see if the request should be throttled. Implement the check to see if the request should be throttled.
@ -100,7 +92,7 @@ class SimpleRateThottle(BaseThrottle):
if self.rate is None: if self.rate is None:
return True return True
self.key = self.get_cache_key(request) self.key = self.get_cache_key(request, view)
self.history = cache.get(self.key, []) self.history = cache.get(self.key, [])
self.now = self.timer() self.now = self.timer()
@ -149,7 +141,7 @@ class AnonRateThrottle(SimpleRateThottle):
""" """
scope = 'anon' scope = 'anon'
def get_cache_key(self, request): def get_cache_key(self, request, view):
if request.user.is_authenticated(): if request.user.is_authenticated():
return None # Only throttle unauthenticated requests. return None # Only throttle unauthenticated requests.
@ -171,7 +163,7 @@ class UserRateThrottle(SimpleRateThottle):
""" """
scope = 'user' scope = 'user'
def get_cache_key(self, request): def get_cache_key(self, request, view):
if request.user.is_authenticated(): if request.user.is_authenticated():
ident = request.user.id ident = request.user.id
else: else:
@ -190,25 +182,20 @@ class ScopedRateThrottle(SimpleRateThottle):
throttled. The unique cache key will be generated by concatenating the throttled. The unique cache key will be generated by concatenating the
user id of the request, and the scope of the view being accessed. user id of the request, and the scope of the view being accessed.
""" """
scope_attr = 'throttle_scope' scope_attr = 'throttle_scope'
def __init__(self, view): def get_cache_key(self, request, view):
"""
Scope is determined from the view being accessed.
"""
self.scope = getattr(self.view, self.scope_attr, None)
super(ScopedRateThrottle, self).__init__(view)
def get_cache_key(self, request):
""" """
If `view.throttle_scope` is not set, don't apply this throttle. If `view.throttle_scope` is not set, don't apply this throttle.
Otherwise generate the unique cache key by concatenating the user id Otherwise generate the unique cache key by concatenating the user id
with the '.throttle_scope` property of the view. with the '.throttle_scope` property of the view.
""" """
if not self.scope: scope = getattr(view, self.scope_attr, None)
return None # Only throttle views if `.throttle_scope` is set.
if not scope:
# Only throttle views if `.throttle_scope` is set on the view.
return None
if request.user.is_authenticated(): if request.user.is_authenticated():
ident = request.user.id ident = request.user.id
@ -216,6 +203,6 @@ class ScopedRateThrottle(SimpleRateThottle):
ident = request.META.get('REMOTE_ADDR', None) ident = request.META.get('REMOTE_ADDR', None)
return self.cache_format % { return self.cache_format % {
'scope': self.scope, 'scope': scope,
'ident': ident 'ident': ident
} }

View File

@ -189,13 +189,13 @@ class APIView(View):
""" """
Instantiates and returns the list of permissions that this view requires. Instantiates and returns the list of permissions that this view requires.
""" """
return [permission(self) for permission in self.permission_classes] return [permission() for permission in self.permission_classes]
def get_throttles(self): def get_throttles(self):
""" """
Instantiates and returns the list of thottles that this view uses. Instantiates and returns the list of thottles that this view uses.
""" """
return [throttle(self) for throttle in self.throttle_classes] return [throttle() for throttle in self.throttle_classes]
def get_content_negotiator(self): def get_content_negotiator(self):
""" """
@ -220,7 +220,7 @@ class APIView(View):
Return `True` if the request should be permitted. Return `True` if the request should be permitted.
""" """
for permission in self.get_permissions(): for permission in self.get_permissions():
if not permission.has_permission(request, obj): if not permission.has_permission(request, self, obj):
return False return False
return True return True
@ -229,7 +229,7 @@ class APIView(View):
Check if request should be throttled. Check if request should be throttled.
""" """
for throttle in self.get_throttles(): for throttle in self.get_throttles():
if not throttle.allow_request(request): if not throttle.allow_request(request, self):
self.throttled(request, throttle.wait()) self.throttled(request, throttle.wait())
# Dispatch methods # Dispatch methods